Unverified Commit d7256866 authored by Md Fahim Faysal Khan, committed by GitHub

[JAX] Expose context parallel params to jax DPA api (#1292)



Exposed context parallel params to the DPA API.
Signed-off-by: Md Fahim Faysal Khan <mdfahimfaysa@nvidia.com>
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>

---------
Signed-off-by: Md Fahim Faysal Khan <mdfahimfaysa@nvidia.com>
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
Co-authored-by: Michael Goldfarb <mgoldfarb@nvidia.com>
parent c42beef4
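For context, a minimal usage sketch of the newly exposed knobs. The two context_parallel_* field names and their defaults are taken from this diff; the import path, the other constructor arguments, and the shapes are illustrative assumptions, not part of the change:

    # Hypothetical usage sketch of the new context-parallel params on DotProductAttention.
    # Import path and non-CP arguments are assumptions; only the context_parallel_* fields
    # and defaults come from this diff.
    import jax
    import jax.numpy as jnp
    from transformer_engine.jax.flax import DotProductAttention

    attn = DotProductAttention(
        head_dim=64,
        num_attention_heads=16,
        transpose_batch_sequence=False,  # inputs as (batch, seqlen, heads, head_dim)
        context_parallel_causal_load_balanced=True,  # sequences pre-reordered for CP load balancing
        context_parallel_axis="cp",  # name of the mesh axis that shards the sequence dimension
    )

    q = k = v = jnp.zeros((2, 512, 16, 64), jnp.bfloat16)
    variables = attn.init(jax.random.PRNGKey(0), q, k, v)
    out = attn.apply(variables, q, k, v, deterministic=True)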
@@ -133,7 +133,6 @@ class TestDistributedSelfAttn:
             seqlen,
             hidden,
             None,  # no window
-            False,  # not context parallel
         ):
             pytest.skip(f"No FusedAttn backend found")
@@ -268,7 +267,6 @@ class TestDistributedCrossAttn:
             seqlen,
             hidden,
             None,  # no window
-            False,  # not context parallel
         ):
             pytest.skip(f"No FusedAttn backend found")
@@ -425,22 +423,32 @@ class TestDistributedContextParallelSelfAttn:
         num_kv_heads = num_head // kv_groups
         scaling_factor = 1.0 / np.sqrt(num_head)
 
-        if not is_fused_attn_kernel_available(
-            dtype,
-            dtype,
-            qkv_layout,
-            attn_bias_type,
-            attn_mask_type,
-            dropout_prob,
-            num_head,
-            num_kv_heads,
-            seqlen,
-            seqlen,
-            hidden,
-            None,  # no window
-            cp_size > 1,
-        ):
-            pytest.skip(f"No FusedAttn backend found")
+        def check_has_backend_for_mask(mask_type):
+            return is_fused_attn_kernel_available(
+                dtype,
+                dtype,
+                qkv_layout,
+                attn_bias_type,
+                mask_type,
+                dropout_prob,
+                num_head,
+                num_kv_heads,
+                seqlen,
+                seqlen,
+                hidden,
+                None,
+            )  # no SWA for CP
+
+        # For causal masking we depend on having bottom right support also.
+        # The API does not check this and instead we rely on lower level checks to raise
+        # an exception if the backend is not supported. This was a deliberate API
+        # decision to keep the CP size or flag out of the function.
+        has_backend = check_has_backend_for_mask(attn_mask_type)
+        if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK:
+            has_backend &= check_has_backend_for_mask(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)
+
+        if not has_backend:
+            pytest.skip(f"No FusedAttn backend found {cp_size=} {attn_mask_type=}.")
 
         if dp_size > 1 and batch % dp_size != 0:
             pytest.skip(f"Skipping {batch=} not a multiple of {dp_size=}")
@@ -190,7 +190,6 @@ def is_fused_attn_kernel_available(
     kv_max_seqlen,
     head_dim,
     window_size: Optional[Tuple[int, int]] = None,
-    is_context_parallel: bool = False,
 ):
     """
     To check whether the fused attention kernel is supported
@@ -215,11 +214,6 @@ def is_fused_attn_kernel_available(
     if not make_helper(attn_mask_type).is_fused_attn_kernel_available():
         return False
 
-    # For context parallel need to check additional masking types
-    if is_context_parallel and attn_mask_type == AttnMaskType.CAUSAL_MASK:
-        if not make_helper(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK).is_fused_attn_kernel_available():
-            return False
-
     return True
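With is_context_parallel gone from the signature, the bottom-right-mask requirement is no longer folded into this function's answer; context-parallel callers perform the extra check themselves, as the updated test above now does. A minimal caller-side sketch, assuming the symbols are importable from transformer_engine.jax.attention (the argument order follows the call sites shown in this diff):

    # Sketch of the caller-side check that replaces the removed is_context_parallel logic.
    # Import path is an assumption; argument order mirrors the call sites in this diff.
    from transformer_engine.jax.attention import AttnMaskType, is_fused_attn_kernel_available

    def cp_backend_available(dtype, qkv_layout, bias_type, mask_type,
                             dropout_prob, num_head, num_kv_heads,
                             seqlen, hidden, cp_size):
        def check(mask):
            return is_fused_attn_kernel_available(
                dtype, dtype, qkv_layout, bias_type, mask, dropout_prob,
                num_head, num_kv_heads, seqlen, seqlen, hidden,
                None,  # no sliding window with CP
            )

        ok = check(mask_type)
        # A causal mask under CP additionally needs bottom-right mask support.
        if cp_size > 1 and mask_type == AttnMaskType.CAUSAL_MASK:
            ok = ok and check(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)
        return ok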
@@ -262,6 +262,8 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     scale_factor: Optional[float] = None
     transpose_batch_sequence: bool = False
     window_size: Optional[Tuple[int, int]] = None
+    context_parallel_causal_load_balanced: bool = False
+    context_parallel_axis: str = ""
 
     @nn.compact
     def __call__(
@@ -308,6 +310,8 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
                 window_size=self.window_size,
+                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+                context_parallel_axis=self.context_parallel_axis,
             )
         elif self.qkv_layout == QKVLayout.BSHD_BS2HD:
             """kvpacked format, treat
@@ -331,6 +335,8 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
                 window_size=self.window_size,
+                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+                context_parallel_axis=self.context_parallel_axis,
             )
         elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
             if self.transpose_batch_sequence:
@@ -349,6 +355,8 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
                 window_size=self.window_size,
+                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+                context_parallel_axis=self.context_parallel_axis,
             )
         else:
             raise ValueError(f"Unsupported {self.qkv_layout=}.")
@@ -463,6 +471,9 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
         should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
     window_size: Optional[Tuple[int, int]], default = None
         Sliding window size. The default value is no sliding window.
+    context_parallel_causal_load_balanced (bool):
+        Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
+    context_parallel_axis (str): The name of the context parallel axis.
 
     Optimization parameters
     -----------------------
@@ -483,6 +494,8 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     scale_factor: Optional[float] = None
     transpose_batch_sequence: bool = True
     window_size: Optional[Tuple[int, int]] = None
+    context_parallel_causal_load_balanced: bool = False
+    context_parallel_axis: str = ""
 
     @nn.compact
     def __call__(
@@ -614,6 +627,8 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
             transpose_batch_sequence=self.transpose_batch_sequence,
             qkv_layout=qkv_layout,
             window_size=self.window_size,
+            context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
+            context_parallel_axis=self.context_parallel_axis,
         )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)
 
         return x