Unverified Commit df699655 authored by Michael Goldfarb's avatar Michael Goldfarb Committed by GitHub
Browse files

[JAX] Fix unit tests to work around cuDNN 9.4 regression of 0 length sequences (#1179)



Modify unit tests to work around cuDNN 9.4 regression.
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
parent c55007b8
......@@ -29,7 +29,10 @@ from transformer_engine.jax.attention import (
get_qkv_format,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend
from transformer_engine.transformer_engine_jax import (
NVTE_Fused_Attn_Backend,
get_cudnn_version,
)
from utils import assert_allclose
......@@ -230,7 +233,14 @@ def customcall_fused_dpa(
kwargs.pop("max_segments_per_seq")
return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
return fused_attn_thd(
qkv_args, bias, seqlens_q, seqlens_kv, offsets_q, offsets_kv, dropout_rng, **kwargs
qkv_args,
bias,
seqlens_q,
seqlens_kv,
offsets_q,
offsets_kv,
dropout_rng,
**kwargs,
).astype(query.dtype)
......@@ -265,6 +275,15 @@ class FusedAttnRunner:
qkv_layout: QKVLayout
bias_shape: BiasShape
def _get_max_segments_per_sequence(self):
    """Return the ``max_segments_per_seq`` limit to use in the test.

    cuDNN 9.4.x has a known issue generating zero-length ragged tensors
    (see https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0),
    so on those versions the limit is kept equal to the real segment count
    to avoid hitting the zero-length case. On all other versions one extra
    slot is allowed so the tests also exercise runtime_segments < max_segments.
    """
    on_affected_cudnn = 90400 <= get_cudnn_version() < 90500
    if on_affected_cudnn:
        return self.num_segments_per_seq
    # +1 exercises the runtime_segments < max_segments code path.
    return self.num_segments_per_seq + 1
def _check_configs(self):
# TODO(rewang): probably adds this in is_fused_attn_available
if get_qkv_format(self.qkv_layout) == QKVFormat.THD and not self.attn_mask_type in [
......@@ -299,7 +318,10 @@ class FusedAttnRunner:
self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
and self.bias_shape != BiasShape.BIAS_1HSS
):
if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
if self.attn_mask_type not in [
AttnMaskType.NO_MASK,
AttnMaskType.CAUSAL_MASK,
]:
pytest.skip(
"B1SS, BHSS and 11SS bias shapes are only supported for "
"AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK."
......@@ -316,7 +338,12 @@ class FusedAttnRunner:
q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
k_shape = v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim)
k_shape = v_shape = (
self.batch_size,
self.max_seqlen_kv,
self.num_heads_kv,
self.head_dim,
)
if self.attn_bias_type == AttnBiasType.NO_BIAS:
bias_shape = None
......@@ -325,7 +352,12 @@ class FusedAttnRunner:
elif self.bias_shape == BiasShape.BIAS_B1SS:
bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
elif self.bias_shape == BiasShape.BIAS_BHSS:
bias_shape = (self.batch_size, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
bias_shape = (
self.batch_size,
self.num_heads_q,
self.max_seqlen_q,
self.max_seqlen_kv,
)
elif self.bias_shape == BiasShape.BIAS_11SS:
bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
else:
......@@ -405,7 +437,10 @@ class FusedAttnRunner:
self.segment_pad_kv = self.segment_pad_q
else:
self.token_kv, self.segment_pad_kv = generate_random_segment_ids(
self.batch_size, self.max_seqlen_kv, self.num_segments_per_seq, seed=2024
self.batch_size,
self.max_seqlen_kv,
self.num_segments_per_seq,
seed=2024,
)
self.pad_q = self.segment_pad_q
self.pad_kv = self.segment_pad_kv
......@@ -464,8 +499,7 @@ class FusedAttnRunner:
"dropout_probability": self.dropout_prob,
"is_training": self.is_training,
"qkv_layout": self.qkv_layout,
# +1 for testing runtime_segments < max_segments
"max_segments_per_seq": self.num_segments_per_seq + 1,
"max_segments_per_seq": self._get_max_segments_per_sequence(),
}
# Convert the outputs to float32 for the elementwise comparison
......@@ -522,7 +556,7 @@ class FusedAttnRunner:
"dropout_probability": self.dropout_prob,
"is_training": self.is_training,
"qkv_layout": self.qkv_layout,
"max_segments_per_seq": self.num_segments_per_seq + 1,
"max_segments_per_seq": self._get_max_segments_per_sequence(),
}
# We can compute dBias only for the [1, h, s, s] layout
......@@ -635,7 +669,16 @@ class FusedAttnRunner:
pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
pytest.param(4, 512, 128, 16, 16, 64, jnp.bfloat16, id="4-512-128-16-16-64-BF16-CROSS"),
pytest.param(2, 2048, 1024, 12, 12, 64, jnp.bfloat16, id="2-2048-1048-12-12-64-BF16-CROSS"),
pytest.param(
2,
2048,
1024,
12,
12,
64,
jnp.bfloat16,
id="2-2048-1048-12-12-64-BF16-CROSS",
),
pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"),
pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment