Unverified Commit 20c55e46 authored by Michael Goldfarb, committed by GitHub

Check for backend support in Jax context parallel fused attention test (#1227)



Update test to check support for context parallel attention.
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
parent 86f07be4
@@ -124,8 +124,10 @@ class TestDistributedSelfAttn:
             seqlen,
             seqlen,
             hidden,
+            None,  # no window
+            False,  # not context parallel
         ):
-            pytest.skip(f"No FusedAttn backwend found")
+            pytest.skip(f"No FusedAttn backend found")

         def target_func(qkv, bias, mask):
             return jnp.mean(
@@ -257,8 +259,10 @@ class TestDistributedCrossAttn:
             seqlen,
             seqlen,
             hidden,
+            None,  # no window
+            False,  # not context parallel
         ):
-            pytest.skip(f"No FusedAttn backwend found")
+            pytest.skip(f"No FusedAttn backend found")

         def target_func(q, kv, mask):
             return jnp.mean(
@@ -403,7 +407,24 @@ class TestDistributedContexParallelSelfAttn:
         _, seqlen, num_head, hidden = data_shape
         num_kv_heads = num_head // kv_groups

-        # make sure the mesh evenly divides cp and tp axis
+        if not is_fused_attn_kernel_available(
+            dtype,
+            dtype,
+            qkv_layout,
+            attn_bias_type,
+            attn_mask_type,
+            dropout_prob,
+            num_head,
+            num_kv_heads,
+            seqlen,
+            seqlen,
+            hidden,
+            None,  # no window
+            cp_size > 1,
+        ):
+            pytest.skip(f"No FusedAttn backend found")
+
+        # make sure the mesh evenly divides cp and tp axis
         if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0:
             pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}")
...
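For reference, all three tests above now gate on `is_fused_attn_kernel_available` before running, passing the two new trailing arguments (window size and context-parallel flag). A minimal sketch of that skip-guard pattern follows; the import path, enum spellings, and concrete shape values are assumptions for illustration, not taken from this diff.

```python
# Hedged sketch of the skip-guard pattern adopted by the tests above.
# Assumptions: the import path, enum member names, and shape values are
# illustrative; the argument order follows the call sites in the diff.
import jax.numpy as jnp
import pytest

from transformer_engine.jax.attention import (  # import path assumed
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
    is_fused_attn_kernel_available,
)


def test_self_attn_sketch():
    seqlen, num_head, num_kv_heads, hidden = 2048, 16, 16, 64
    cp_size = 2  # context-parallel world size

    if not is_fused_attn_kernel_available(
        jnp.bfloat16,              # q dtype
        jnp.bfloat16,              # kv dtype
        QKVLayout.BS3HD,           # packed-QKV layout
        AttnBiasType.NO_BIAS,
        AttnMaskType.CAUSAL_MASK,
        0.0,                       # dropout probability
        num_head,
        num_kv_heads,
        seqlen,                    # q max seqlen
        seqlen,                    # kv max seqlen
        hidden,                    # head dim
        None,                      # no sliding window
        cp_size > 1,               # is_context_parallel
    ):
        pytest.skip("No FusedAttn backend found")
```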
@@ -190,24 +190,37 @@ def is_fused_attn_kernel_available(
     kv_max_seqlen,
     head_dim,
     window_size: Optional[Tuple[int, int]] = None,
+    is_context_parallel: bool = False,
 ):
     """
     To check whether the fused attention kernel is supported
     """
-    return tex.FusedAttnHelper(
-        q_dtype,
-        kv_dtype,
-        qkv_layout.value,
-        attn_bias_type.value,
-        attn_mask_type.value,
-        dropout_probability,
-        q_num_heads,
-        kv_num_heads,
-        q_max_seqlen,
-        kv_max_seqlen,
-        head_dim,
-        (-1, -1) if window_size is None else window_size,
-    ).is_fused_attn_kernel_available()
+
+    def make_helper(attn_mask_type):
+        return tex.FusedAttnHelper(
+            q_dtype,
+            kv_dtype,
+            qkv_layout.value,
+            attn_bias_type.value,
+            attn_mask_type.value,
+            dropout_probability,
+            q_num_heads,
+            kv_num_heads,
+            q_max_seqlen,
+            kv_max_seqlen,
+            head_dim,
+            (-1, -1) if window_size is None else window_size,
+        )
+
+    if not make_helper(attn_mask_type).is_fused_attn_kernel_available():
+        return False
+
+    # For context parallel, we need to check additional masking types
+    if is_context_parallel and attn_mask_type == AttnMaskType.CAUSAL_MASK:
+        if not make_helper(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK).is_fused_attn_kernel_available():
+            return False
+
+    return True


 def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
...
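The refactor above is the heart of the change: under context parallelism, every sequence chunk after the first attends to all gathered earlier KV positions plus a causal prefix of its own chunk, which is a bottom-right-aligned causal pattern. The helper therefore also probes `CAUSAL_BOTTOM_RIGHT_MASK` availability before reporting support. A small pure-JAX illustration of the two alignments (the helper names and shapes below are ours, not part of the diff):

```python
# Pure-JAX illustration of top-left vs. bottom-right causal alignment.
# Helper names and shapes are ours, not part of the diff.
import jax.numpy as jnp


def causal_top_left(q_len, kv_len):
    # Standard causal mask: query i attends kv j iff j <= i,
    # aligned at the top-left corner.
    return jnp.tril(jnp.ones((q_len, kv_len), dtype=bool))


def causal_bottom_right(q_len, kv_len):
    # Bottom-right alignment: the *last* query attends the *last* kv.
    # This is what a non-first context-parallel chunk sees once all
    # earlier keys/values have been gathered.
    return jnp.tril(jnp.ones((q_len, kv_len), dtype=bool), k=kv_len - q_len)


# A rank holding the second chunk of a seqlen-8 sequence (chunk size 4)
# has 4 local queries against 8 gathered keys/values:
print(causal_bottom_right(4, 8).astype(jnp.int32))
```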
@@ -923,26 +923,30 @@ class _FusedAttnCPWithAllGatherHelper:
         header = "Context parallel fused attention"

         allowed_layouts = [NVTE_QKV_Layout.NVTE_BSHD_BS2HD, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD]
-        assert self.config.qkv_layout in allowed_layouts, (
-            f"{header} only supports layouts: {','.join([str(x) for x in allowed_layouts])} got:"
-            f" {self.config.qkv_layout}"
-        )
+        if self.config.qkv_layout not in allowed_layouts:
+            raise ValueError(
+                f"{header} only supports layouts:"
+                f" {','.join([str(x) for x in allowed_layouts])} got: {self.config.qkv_layout}"
+            )

-        assert (
-            self.config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS
-        ), f"{header} does not support bias got: {self.config.attn_bias_type}"
+        if self.config.attn_bias_type != NVTE_Bias_Type.NVTE_NO_BIAS:
+            raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}")

         allowed_masks = [NVTE_Mask_Type.NVTE_NO_MASK, NVTE_Mask_Type.NVTE_CAUSAL_MASK]
-        assert self.config.attn_mask_type in allowed_masks, (
-            f"{header} only supports masking types: "
-            f" {','.join([str(x) for x in allowed_masks])} got: {self.config.attn_mask_type}"
-        )
+        if self.config.attn_mask_type not in allowed_masks:
+            raise ValueError(
+                f"{header} only supports masking types: "
+                f" {','.join([str(x) for x in allowed_masks])} got: {self.config.attn_mask_type}"
+            )

-        assert self.config.max_segments_per_seq == 1, (
-            f"{header} only supports max_segments_per_seq == 1 got:"
-            f" {self.config.max_segments_per_seq}"
-        )
-        assert self.config.dropout_probability == 0.0, f"{header} does not support dropout"
+        if self.config.max_segments_per_seq != 1:
+            raise ValueError(
+                f"{header} only supports max_segments_per_seq == 1 got:"
+                f" {self.config.max_segments_per_seq}"
+            )
+        if self.config.dropout_probability != 0.0:
+            raise ValueError(f"{header} does not support dropout")

     def get_adjusted_mask(self):
         """Converts the mask for context parallelism."""
...
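The last hunk swaps assertions for `ValueError` in `check_supported`. The practical difference: `assert` statements vanish when Python runs with `-O`, and `AssertionError` reads as an internal bug rather than a bad user configuration, while a `ValueError` always fires and can be handled by callers. A stand-in sketch of the behavior (the `_Config` stub below is ours, not the real config class):

```python
# Why ValueError beats assert for config validation: asserts are
# stripped under `python -O`, while ValueError is always raised and is
# catchable as a user error. _Config is an illustrative stub.
class _Config:
    dropout_probability = 0.1


def check_supported(config):
    if config.dropout_probability != 0.0:
        raise ValueError("Context parallel fused attention does not support dropout")


try:
    check_supported(_Config())
except ValueError as err:
    print(f"Unsupported config: {err}")
```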