"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "258d084237dccef6d862d20eb2fd63c77315cb36"
Commit 20c55e46 authored by Michael Goldfarb, committed by GitHub

Check for backend support in Jax context parallel fused attention test (#1227)



Update the test to check backend support for context parallel attention.
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
parent 86f07be4
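The change guards each test configuration with is_fused_attn_kernel_available before running, so combinations that no fused attention backend can execute are skipped instead of failing. A minimal sketch of the guard as the context parallel test applies it below (variable names come from the test's parametrization; the import path is assumed from the TransformerEngine JAX package layout):

    import pytest
    from transformer_engine.jax.attention import is_fused_attn_kernel_available

    # Skip when no fused attention backend supports this configuration; the last
    # two arguments are the sliding-window and context-parallel inputs added here.
    if not is_fused_attn_kernel_available(
        dtype,          # q dtype
        dtype,          # kv dtype
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        num_head,       # number of query heads
        num_kv_heads,   # number of key/value heads
        seqlen,         # max q sequence length
        seqlen,         # max kv sequence length
        hidden,         # head dimension
        None,           # no sliding window
        cp_size > 1,    # is_context_parallel
    ):
        pytest.skip("No FusedAttn backend found")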
@@ -124,8 +124,10 @@ class TestDistributedSelfAttn:
             seqlen,
             seqlen,
             hidden,
+            None,  # no window
+            False,  # not context parallel
         ):
-            pytest.skip(f"No FusedAttn backwend found")
+            pytest.skip(f"No FusedAttn backend found")
 
         def target_func(qkv, bias, mask):
             return jnp.mean(
@@ -257,8 +259,10 @@ class TestDistributedCrossAttn:
             seqlen,
             seqlen,
             hidden,
+            None,  # no window
+            False,  # not context parallel
         ):
-            pytest.skip(f"No FusedAttn backwend found")
+            pytest.skip(f"No FusedAttn backend found")
 
         def target_func(q, kv, mask):
             return jnp.mean(
@@ -403,7 +407,24 @@ class TestDistributedContexParallelSelfAttn:
         _, seqlen, num_head, hidden = data_shape
         num_kv_heads = num_head // kv_groups
 
+        if not is_fused_attn_kernel_available(
+            dtype,
+            dtype,
+            qkv_layout,
+            attn_bias_type,
+            attn_mask_type,
+            dropout_prob,
+            num_head,
+            num_kv_heads,
+            seqlen,
+            seqlen,
+            hidden,
+            None,  # no window
+            cp_size > 1,
+        ):
+            pytest.skip(f"No FusedAttn backend found")
+
         # make sure the mesh evently divides cp and tp axis
         if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0:
             pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}")
@@ -190,10 +190,13 @@ def is_fused_attn_kernel_available(
     kv_max_seqlen,
     head_dim,
     window_size: Optional[Tuple[int, int]] = None,
+    is_context_parallel: bool = False,
 ):
     """
     To check whether the fused attention kernel is supported
     """
+
+    def make_helper(attn_mask_type):
     return tex.FusedAttnHelper(
         q_dtype,
         kv_dtype,
@@ -207,7 +210,17 @@ def is_fused_attn_kernel_available(
         kv_max_seqlen,
         head_dim,
         (-1, -1) if window_size is None else window_size,
-    ).is_fused_attn_kernel_available()
+        )
+
+    if not make_helper(attn_mask_type).is_fused_attn_kernel_available():
+        return False
+
+    # For context parallel need to check additional masking types
+    if is_context_parallel and attn_mask_type == AttnMaskType.CAUSAL_MASK:
+        if not make_helper(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK).is_fused_attn_kernel_available():
+            return False
+
+    return True
 
 
 def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
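The tex.FusedAttnHelper(...) argument lines above appear unchanged only because whitespace-only edits are hidden in this view; they are re-indented into the new make_helper closure. The extra probe matters because, under context parallelism, a causal problem is split into sequence chunks and a non-first query chunk attends to all earlier key/value positions plus its own chunk, which corresponds to a bottom-right-aligned causal mask (queries shorter than keys/values). A backend therefore has to support CAUSAL_BOTTOM_RIGHT_MASK in addition to the requested CAUSAL_MASK. A hedged usage sketch, with placeholder values and the import path assumed from the package layout:

    from transformer_engine.jax.attention import AttnMaskType, is_fused_attn_kernel_available

    common = (
        dtype, dtype,              # q/kv dtypes
        qkv_layout, attn_bias_type,
        AttnMaskType.CAUSAL_MASK,
        0.0,                       # dropout probability
        num_head, num_kv_heads,
        seqlen, seqlen, head_dim,
    )
    # The same configuration can be supported on a single device yet rejected for
    # context parallelism if the backend lacks bottom-right causal masking.
    single_device_ok = is_fused_attn_kernel_available(*common, None, False)
    context_parallel_ok = is_fused_attn_kernel_available(*common, None, True)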
@@ -923,26 +923,30 @@ class _FusedAttnCPWithAllGatherHelper:
         header = "Context parallel fused attention"
 
         allowed_layouts = [NVTE_QKV_Layout.NVTE_BSHD_BS2HD, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD]
-        assert self.config.qkv_layout in allowed_layouts, (
-            f"{header} only supports layouts: {','.join([str(x) for x in allowed_layouts])} got:"
-            f" {self.config.qkv_layout}"
-        )
+        if self.config.qkv_layout not in allowed_layouts:
+            raise ValueError(
+                f"{header} only supports layouts:"
+                f" {','.join([str(x) for x in allowed_layouts])} got: {self.config.qkv_layout}"
+            )
 
-        assert (
-            self.config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS
-        ), f"{header} does not support bias got: {self.config.attn_bias_type}"
+        if self.config.attn_bias_type != NVTE_Bias_Type.NVTE_NO_BIAS:
+            raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}")
 
         allowed_masks = [NVTE_Mask_Type.NVTE_NO_MASK, NVTE_Mask_Type.NVTE_CAUSAL_MASK]
-        assert self.config.attn_mask_type in allowed_masks, (
-            f"{header} only supports masking types: "
-            f" {','.join([str(x) for x in allowed_masks])} got: {self.config.attn_mask_type}"
-        )
+        if self.config.attn_mask_type not in allowed_masks:
+            raise ValueError(
+                f"{header} only supports masking types: "
+                f" {','.join([str(x) for x in allowed_masks])} got: {self.config.attn_mask_type}"
+            )
 
-        assert self.config.max_segments_per_seq == 1, (
-            f"{header} only supports max_segments_per_seq == 1 got:"
-            f" {self.config.max_segments_per_seq}"
-        )
+        if self.config.max_segments_per_seq != 1:
+            raise ValueError(
+                f"{header} only supports max_segments_per_seq == 1 got:"
+                f" {self.config.max_segments_per_seq}"
+            )
 
-        assert self.config.dropout_probability == 0.0, f"{header} does not support dropout"
+        if self.config.dropout_probability != 0.0:
+            raise ValueError(f"{header} does not support dropout")
 
     def get_adjusted_mask(self):
         """Converts the mask for context parallelism."""
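Replacing the assert statements with explicit ValueError raises keeps these configuration checks active under python -O (which strips asserts) and gives callers a conventional exception to catch. A hedged sketch of the resulting behavior; the method name check_supported is an assumption made for illustration, since the enclosing method signature is not part of this diff:

    # Hypothetical caller of the helper's configuration check after this change.
    try:
        helper.check_supported()  # method name assumed, not shown in the diff
    except ValueError as err:
        print(f"Context parallel fused attention not applicable: {err}")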