"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "1c702b4cff6fcee1b92857c283e7c3eb20534923"
Unverified Commit 20c55e46 authored by Michael Goldfarb's avatar Michael Goldfarb Committed by GitHub
Browse files

Check for backend support in Jax context parallel fused attention test (#1227)



Update test to check support for context parallel attention.
Signed-off-by: default avatarMichael Goldfarb <mgoldfarb@nvidia.com>
parent 86f07be4
......@@ -124,8 +124,10 @@ class TestDistributedSelfAttn:
seqlen,
seqlen,
hidden,
None, # no window
False, # not context parallel
):
pytest.skip(f"No FusedAttn backwend found")
pytest.skip(f"No FusedAttn backend found")
def target_func(qkv, bias, mask):
return jnp.mean(
......@@ -257,8 +259,10 @@ class TestDistributedCrossAttn:
seqlen,
seqlen,
hidden,
None, # no window
False, # not context parallel
):
pytest.skip(f"No FusedAttn backwend found")
pytest.skip(f"No FusedAttn backend found")
def target_func(q, kv, mask):
return jnp.mean(
......@@ -403,7 +407,24 @@ class TestDistributedContexParallelSelfAttn:
_, seqlen, num_head, hidden = data_shape
num_kv_heads = num_head // kv_groups
# make sure the mesh evently divides cp and tp axis
if not is_fused_attn_kernel_available(
dtype,
dtype,
qkv_layout,
attn_bias_type,
attn_mask_type,
dropout_prob,
num_head,
num_kv_heads,
seqlen,
seqlen,
hidden,
None, # no window
cp_size > 1,
):
pytest.skip(f"No FusedAttn backend found")
    # make sure the mesh evenly divides the cp and tp axes
if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0:
pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}")
......
......@@ -190,10 +190,13 @@ def is_fused_attn_kernel_available(
kv_max_seqlen,
head_dim,
window_size: Optional[Tuple[int, int]] = None,
is_context_parallel: bool = False,
):
"""
To check whether the fused attention kernel is supported
"""
def make_helper(attn_mask_type):
return tex.FusedAttnHelper(
q_dtype,
kv_dtype,
......@@ -207,7 +210,17 @@ def is_fused_attn_kernel_available(
kv_max_seqlen,
head_dim,
(-1, -1) if window_size is None else window_size,
).is_fused_attn_kernel_available()
)
if not make_helper(attn_mask_type).is_fused_attn_kernel_available():
return False
    # For context parallelism we need to check additional masking types
if is_context_parallel and attn_mask_type == AttnMaskType.CAUSAL_MASK:
if not make_helper(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK).is_fused_attn_kernel_available():
return False
return True
def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
......
......@@ -923,26 +923,30 @@ class _FusedAttnCPWithAllGatherHelper:
header = "Context parallel fused attention"
allowed_layouts = [NVTE_QKV_Layout.NVTE_BSHD_BS2HD, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD]
assert self.config.qkv_layout in allowed_layouts, (
f"{header} only supports layouts: {','.join([str(x) for x in allowed_layouts])} got:"
f" {self.config.qkv_layout}"
if self.config.qkv_layout not in allowed_layouts:
raise ValueError(
f"{header} only supports layouts:"
f" {','.join([str(x) for x in allowed_layouts])} got: {self.config.qkv_layout}"
)
assert (
self.config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS
), f"{header} does not support bias got: {self.config.attn_bias_type}"
if self.config.attn_bias_type != NVTE_Bias_Type.NVTE_NO_BIAS:
raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}")
allowed_masks = [NVTE_Mask_Type.NVTE_NO_MASK, NVTE_Mask_Type.NVTE_CAUSAL_MASK]
assert self.config.attn_mask_type in allowed_masks, (
if self.config.attn_mask_type not in allowed_masks:
raise ValueError(
f"{header} only supports masking types: "
f" {','.join([str(x) for x in allowed_masks])} got: {self.config.attn_mask_type}"
)
assert self.config.max_segments_per_seq == 1, (
if self.config.max_segments_per_seq != 1:
raise ValueError(
f"{header} only supports max_segments_per_seq == 1 got:"
f" {self.config.max_segments_per_seq}"
)
assert self.config.dropout_probability == 0.0, f"{header} does not support dropout"
if self.config.dropout_probability != 0.0:
raise ValueError(f"{header} does not support dropout")
def get_adjusted_mask(self):
"""Converts the mask for context parallelism."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment