Unverified Commit d7256866 authored by Md Fahim Faysal Khan, committed by GitHub

[JAX] Expose context parallel params to jax DPA api (#1292)



Exposed context parallel params to DPA api
Signed-off-by: Md Fahim Faysal Khan <mdfahimfaysa@nvidia.com>
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>

---------
Signed-off-by: Md Fahim Faysal Khan <mdfahimfaysa@nvidia.com>
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
Co-authored-by: Michael Goldfarb <mgoldfarb@nvidia.com>
parent c42beef4
@@ -133,7 +133,6 @@ class TestDistributedSelfAttn:
             seqlen,
             hidden,
             None,  # no window
-            False,  # not context parallel
         ):
             pytest.skip(f"No FusedAttn backend found")
@@ -268,7 +267,6 @@ class TestDistributedCrossAttn:
             seqlen,
             hidden,
             None,  # no window
-            False,  # not context parallel
         ):
             pytest.skip(f"No FusedAttn backend found")
@@ -425,7 +423,8 @@ class TestDistributedContextParallelSelfAttn:
         num_kv_heads = num_head // kv_groups
         scaling_factor = 1.0 / np.sqrt(num_head)

-        if not is_fused_attn_kernel_available(
+        def check_has_backend_for_mask(mask_type):
+            return is_fused_attn_kernel_available(
                 dtype,
                 dtype,
                 qkv_layout,
@@ -437,10 +436,19 @@ class TestDistributedContextParallelSelfAttn:
                 seqlen,
                 seqlen,
                 hidden,
-                None,  # no window
-                cp_size > 1,
-        ):
-            pytest.skip(f"No FusedAttn backend found")
+                None,
+            )  # no SWA for CP
+
+        # For causal masking we depend on having bottom-right support also.
+        # The API does not check this and instead we rely on lower-level checks to raise
+        # an exception if the step backend is not supported. This was a deliberate API
+        # decision to keep the CP size or flag out of the function.
+        has_backend = check_has_backend_for_mask(attn_mask_type)
+        if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK:
+            has_backend &= check_has_backend_for_mask(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)
+
+        if not has_backend:
+            pytest.skip(f"No FusedAttn backend found {cp_size=} {attn_mask_type=}.")

         if dp_size > 1 and batch % dp_size != 0:
             pytest.skip(f"Skipping {batch=} not a multiple of {dp_size=}")
...
@@ -190,7 +190,6 @@ def is_fused_attn_kernel_available(
     kv_max_seqlen,
     head_dim,
     window_size: Optional[Tuple[int, int]] = None,
-    is_context_parallel: bool = False,
 ):
     """
     To check whether the fused attention kernel is supported
@@ -215,11 +214,6 @@ def is_fused_attn_kernel_available(
     if not make_helper(attn_mask_type).is_fused_attn_kernel_available():
         return False

-    # For context parallel need to check additional masking types
-    if is_context_parallel and attn_mask_type == AttnMaskType.CAUSAL_MASK:
-        if not make_helper(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK).is_fused_attn_kernel_available():
-            return False
-
     return True
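With `is_context_parallel` removed, the helper now answers a single question: is a fused kernel available for exactly this mask type? CP-aware callers compose the two probes themselves, as the updated test does. A hedged sketch of a plain (non-CP) call after this change; the argument order is inferred from the test above, and the import path and argument names other than those visible in the diff (`kv_max_seqlen`, `head_dim`, `window_size`) are assumptions:

```python
from transformer_engine.jax.attention import is_fused_attn_kernel_available  # assumed import path

def fused_attn_supported(dtype, qkv_layout, attn_bias_type, attn_mask_type,
                         dropout_prob, num_head, num_kv_heads, seqlen, hidden):
    # Argument order inferred from the test; the is_context_parallel flag
    # is simply dropped after this PR.
    return is_fused_attn_kernel_available(
        dtype, dtype, qkv_layout,
        attn_bias_type, attn_mask_type, dropout_prob,
        num_head, num_kv_heads,
        seqlen, seqlen,   # q and kv max sequence lengths
        hidden,           # head_dim
        None,             # window_size: no sliding window
    )
```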
...
@@ -262,6 +262,8 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     scale_factor: Optional[float] = None
     transpose_batch_sequence: bool = False
     window_size: Optional[Tuple[int, int]] = None
+    context_parallel_causal_load_balanced: bool = False
+    context_parallel_axis: str = ""

     @nn.compact
     def __call__(
@@ -308,6 +310,8 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
                 window_size=self.window_size,
+                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+                context_parallel_axis=self.context_parallel_axis,
             )
         elif self.qkv_layout == QKVLayout.BSHD_BS2HD:
             """kvpacked format, treat
@@ -331,6 +335,8 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
                 window_size=self.window_size,
+                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+                context_parallel_axis=self.context_parallel_axis,
             )
         elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
             if self.transpose_batch_sequence:
@@ -349,6 +355,8 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
                 window_size=self.window_size,
+                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+                context_parallel_axis=self.context_parallel_axis,
             )
         else:
             raise ValueError(f"Unsupported {self.qkv_layout=}.")
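All three layout branches forward the same two module fields, so the CP configuration lives in one place: the module definition. A minimal sketch of that pattern with an illustrative Flax module (only the two field names come from the diff):

```python
import flax.linen as nn

class TinyFusedAttn(nn.Module):
    # The two CP knobs are construction-time hyperparameters, mirroring
    # _FusedDotProductAttention above; everything else here is illustrative.
    context_parallel_causal_load_balanced: bool = False
    context_parallel_axis: str = ""

    @nn.compact
    def __call__(self, q, k, v):
        # A real implementation dispatches on qkv_layout and forwards both
        # fields into the fused-attention primitive, as each branch above does.
        del k, v  # placeholder body; this sketch only shows the field plumbing
        return q
```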
@@ -463,6 +471,9 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
         should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
     window_size: Optional[Tuple[int, int]], default = None
         Sliding window size. The default value is no sliding window.
+    context_parallel_causal_load_balanced (bool):
+        Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
+    context_parallel_axis (str): The name of the context parallel axis.

     Optimization parameters
     -----------------------
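The docstring addition implies that `context_parallel_axis` names a logical mesh axis. A hedged sketch of what that means in practice; the axis names ("dp", "cp") and device layout are illustrative assumptions, and the snippet assumes a multi-device run with an even device count:

```python
import numpy as np
import jax
from jax.sharding import Mesh

# Illustrative 2-axis mesh: data parallel ("dp") x context parallel ("cp").
devices = np.array(jax.devices()).reshape(-1, 2)
mesh = Mesh(devices, axis_names=("dp", "cp"))

# Within this mesh, a DotProductAttention configured with
#   context_parallel_axis="cp"
# shards attention over the "cp" mesh axis, and
#   context_parallel_causal_load_balanced=True
# declares that the sequences were pre-reordered so each CP rank gets an
# even share of the causal-mask work.
```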
@@ -483,6 +494,8 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     scale_factor: Optional[float] = None
     transpose_batch_sequence: bool = True
     window_size: Optional[Tuple[int, int]] = None
+    context_parallel_causal_load_balanced: bool = False
+    context_parallel_axis: str = ""

     @nn.compact
     def __call__(
@@ -614,6 +627,8 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
             transpose_batch_sequence=self.transpose_batch_sequence,
             qkv_layout=qkv_layout,
             window_size=self.window_size,
+            context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+            context_parallel_axis=self.context_parallel_axis,
         )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)
         return x
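Putting it together, an end-to-end hedged sketch of constructing the public module with the new fields. The import path and the non-CP constructor arguments are assumptions; only the two `context_parallel_*` fields come from this diff:

```python
from transformer_engine.jax.flax import DotProductAttention  # assumed export path

dpa = DotProductAttention(
    head_dim=64,                                 # illustrative
    num_attention_heads=16,                      # illustrative
    attention_dropout=0.0,                       # illustrative
    context_parallel_causal_load_balanced=True,  # inputs pre-reordered for CP
    context_parallel_axis="cp",                  # must match the mesh axis name
)
# The module is then applied as usual, e.g.:
# out = dpa.apply(params, query, key, value, mask, deterministic=True)
```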
...