Unverified Commit 20c55e46 authored by Michael Goldfarb, committed by GitHub
Browse files

Check for backend support in Jax context parallel fused attention test (#1227)



Update test to check support for context parallel attention.
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
parent 86f07be4
......@@ -124,8 +124,10 @@ class TestDistributedSelfAttn:
seqlen,
seqlen,
hidden,
None, # no window
False, # not context parallel
):
pytest.skip(f"No FusedAttn backwend found")
pytest.skip(f"No FusedAttn backend found")
def target_func(qkv, bias, mask):
return jnp.mean(
......@@ -257,8 +259,10 @@ class TestDistributedCrossAttn:
seqlen,
seqlen,
hidden,
None, # no window
False, # not context parallel
):
pytest.skip(f"No FusedAttn backwend found")
pytest.skip(f"No FusedAttn backend found")
def target_func(q, kv, mask):
return jnp.mean(
......@@ -403,7 +407,24 @@ class TestDistributedContexParallelSelfAttn:
_, seqlen, num_head, hidden = data_shape
num_kv_heads = num_head // kv_groups
# make sure the mesh evenly divides cp and tp axis
if not is_fused_attn_kernel_available(
dtype,
dtype,
qkv_layout,
attn_bias_type,
attn_mask_type,
dropout_prob,
num_head,
num_kv_heads,
seqlen,
seqlen,
hidden,
None, # no window
cp_size > 1,
):
pytest.skip(f"No FusedAttn backend found")
# make sure the mesh evenly divides cp and tp axis
if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0:
pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}")
......
......@@ -190,24 +190,37 @@ def is_fused_attn_kernel_available(
kv_max_seqlen,
head_dim,
window_size: Optional[Tuple[int, int]] = None,
is_context_parallel: bool = False,
):
"""
To check whether the fused attention kernel is supported
"""
return tex.FusedAttnHelper(
q_dtype,
kv_dtype,
qkv_layout.value,
attn_bias_type.value,
attn_mask_type.value,
dropout_probability,
q_num_heads,
kv_num_heads,
q_max_seqlen,
kv_max_seqlen,
head_dim,
(-1, -1) if window_size is None else window_size,
).is_fused_attn_kernel_available()
def make_helper(attn_mask_type):
return tex.FusedAttnHelper(
q_dtype,
kv_dtype,
qkv_layout.value,
attn_bias_type.value,
attn_mask_type.value,
dropout_probability,
q_num_heads,
kv_num_heads,
q_max_seqlen,
kv_max_seqlen,
head_dim,
(-1, -1) if window_size is None else window_size,
)
if not make_helper(attn_mask_type).is_fused_attn_kernel_available():
return False
# For context parallelism we need to check additional masking types
if is_context_parallel and attn_mask_type == AttnMaskType.CAUSAL_MASK:
if not make_helper(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK).is_fused_attn_kernel_available():
return False
return True
def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
......
......@@ -923,26 +923,30 @@ class _FusedAttnCPWithAllGatherHelper:
header = "Context parallel fused attention"
allowed_layouts = [NVTE_QKV_Layout.NVTE_BSHD_BS2HD, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD]
assert self.config.qkv_layout in allowed_layouts, (
f"{header} only supports layouts: {','.join([str(x) for x in allowed_layouts])} got:"
f" {self.config.qkv_layout}"
)
if self.config.qkv_layout not in allowed_layouts:
raise ValueError(
f"{header} only supports layouts:"
f" {','.join([str(x) for x in allowed_layouts])} got: {self.config.qkv_layout}"
)
assert (
self.config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS
), f"{header} does not support bias got: {self.config.attn_bias_type}"
if self.config.attn_bias_type != NVTE_Bias_Type.NVTE_NO_BIAS:
raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}")
allowed_masks = [NVTE_Mask_Type.NVTE_NO_MASK, NVTE_Mask_Type.NVTE_CAUSAL_MASK]
assert self.config.attn_mask_type in allowed_masks, (
f"{header} only supports masking types: "
f" {','.join([str(x) for x in allowed_masks])} got: {self.config.attn_mask_type}"
)
if self.config.attn_mask_type not in allowed_masks:
raise ValueError(
f"{header} only supports masking types: "
f" {','.join([str(x) for x in allowed_masks])} got: {self.config.attn_mask_type}"
)
assert self.config.max_segments_per_seq == 1, (
f"{header} only supports max_segments_per_seq == 1 got:"
f" {self.config.max_segments_per_seq}"
)
assert self.config.dropout_probability == 0.0, f"{header} does not support dropout"
if self.config.max_segments_per_seq != 1:
raise ValueError(
f"{header} only supports max_segments_per_seq == 1 got:"
f" {self.config.max_segments_per_seq}"
)
if self.config.dropout_probability != 0.0:
raise ValueError(f"{header} does not support dropout")
def get_adjusted_mask(self):
"""Converts the mask for context parallelism."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment