Unverified Commit df699655 authored by Michael Goldfarb, committed by GitHub

[JAX] Fix unit tests to work around cuDNN 9.4 regression of 0 length sequences (#1179)



Modify unit tests to work around cuDNN 9.4 regression.
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
parent c55007b8
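
For context, the workaround gates each test's max_segments_per_seq on the installed cuDNN version: on cuDNN 9.4.x the tests request exactly as many segments as they generate, so no zero-length ragged segment can appear, while on other versions they keep the extra +1 slot that exercises runtime_segments < max_segments. A minimal standalone sketch of that gate follows (the free function name is illustrative; the commit implements it as FusedAttnRunner._get_max_segments_per_sequence in the diff below):

    from transformer_engine.transformer_engine_jax import get_cudnn_version

    def pick_max_segments_per_seq(num_segments_per_seq: int) -> int:
        # cuDNN 9.4.x has a known issue with zero-length ragged tensors, so cap
        # max_segments at the number of segments the test actually generates.
        if 90400 <= get_cudnn_version() < 90500:
            return num_segments_per_seq
        # Otherwise keep the +1 head-room so runtime_segments < max_segments
        # is still exercised.
        return num_segments_per_seq + 1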
@@ -29,7 +29,10 @@ from transformer_engine.jax.attention import (
     get_qkv_format,
 )
 from transformer_engine.jax.cpp_extensions import FusedAttnHelper
-from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend
+from transformer_engine.transformer_engine_jax import (
+    NVTE_Fused_Attn_Backend,
+    get_cudnn_version,
+)
 
 from utils import assert_allclose
@@ -230,7 +233,14 @@ def customcall_fused_dpa(
         kwargs.pop("max_segments_per_seq")
         return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
     return fused_attn_thd(
-        qkv_args, bias, seqlens_q, seqlens_kv, offsets_q, offsets_kv, dropout_rng, **kwargs
+        qkv_args,
+        bias,
+        seqlens_q,
+        seqlens_kv,
+        offsets_q,
+        offsets_kv,
+        dropout_rng,
+        **kwargs,
     ).astype(query.dtype)
@@ -265,6 +275,15 @@ class FusedAttnRunner:
     qkv_layout: QKVLayout
     bias_shape: BiasShape
 
+    # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue
+    # generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
+    def _get_max_segments_per_sequence(self):
+        if 90400 <= get_cudnn_version() < 90500:
+            return self.num_segments_per_seq
+        else:
+            # +1 for testing runtime_segments < max_segments
+            return self.num_segments_per_seq + 1
+
     def _check_configs(self):
         # TODO(rewang): probably adds this in is_fused_attn_available
         if get_qkv_format(self.qkv_layout) == QKVFormat.THD and not self.attn_mask_type in [
@@ -299,7 +318,10 @@ class FusedAttnRunner:
             self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
             and self.bias_shape != BiasShape.BIAS_1HSS
         ):
-            if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
+            if self.attn_mask_type not in [
+                AttnMaskType.NO_MASK,
+                AttnMaskType.CAUSAL_MASK,
+            ]:
                 pytest.skip(
                     "B1SS, BHSS and 11SS bias shapes are only supported for "
                     "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK."
@@ -316,7 +338,12 @@ class FusedAttnRunner:
         q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
 
         q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
-        k_shape = v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim)
+        k_shape = v_shape = (
+            self.batch_size,
+            self.max_seqlen_kv,
+            self.num_heads_kv,
+            self.head_dim,
+        )
         if self.attn_bias_type == AttnBiasType.NO_BIAS:
             bias_shape = None
@@ -325,7 +352,12 @@ class FusedAttnRunner:
         elif self.bias_shape == BiasShape.BIAS_B1SS:
             bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
         elif self.bias_shape == BiasShape.BIAS_BHSS:
-            bias_shape = (self.batch_size, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
+            bias_shape = (
+                self.batch_size,
+                self.num_heads_q,
+                self.max_seqlen_q,
+                self.max_seqlen_kv,
+            )
         elif self.bias_shape == BiasShape.BIAS_11SS:
             bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
         else:
@@ -405,7 +437,10 @@ class FusedAttnRunner:
                 self.segment_pad_kv = self.segment_pad_q
             else:
                 self.token_kv, self.segment_pad_kv = generate_random_segment_ids(
-                    self.batch_size, self.max_seqlen_kv, self.num_segments_per_seq, seed=2024
+                    self.batch_size,
+                    self.max_seqlen_kv,
+                    self.num_segments_per_seq,
+                    seed=2024,
                 )
             self.pad_q = self.segment_pad_q
             self.pad_kv = self.segment_pad_kv
@@ -464,8 +499,7 @@ class FusedAttnRunner:
             "dropout_probability": self.dropout_prob,
             "is_training": self.is_training,
             "qkv_layout": self.qkv_layout,
-            # +1 for testing runtime_segments < max_segments
-            "max_segments_per_seq": self.num_segments_per_seq + 1,
+            "max_segments_per_seq": self._get_max_segments_per_sequence(),
         }
 
         # Convert the outputs to float32 for the elementwise comparison
@@ -522,7 +556,7 @@ class FusedAttnRunner:
             "dropout_probability": self.dropout_prob,
             "is_training": self.is_training,
             "qkv_layout": self.qkv_layout,
-            "max_segments_per_seq": self.num_segments_per_seq + 1,
+            "max_segments_per_seq": self._get_max_segments_per_sequence(),
         }
 
         # We can compute dBias only for the [1, h, s, s] layout
@@ -635,7 +669,16 @@ class FusedAttnRunner:
         pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
         pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
         pytest.param(4, 512, 128, 16, 16, 64, jnp.bfloat16, id="4-512-128-16-16-64-BF16-CROSS"),
-        pytest.param(2, 2048, 1024, 12, 12, 64, jnp.bfloat16, id="2-2048-1048-12-12-64-BF16-CROSS"),
+        pytest.param(
+            2,
+            2048,
+            1024,
+            12,
+            12,
+            64,
+            jnp.bfloat16,
+            id="2-2048-1048-12-12-64-BF16-CROSS",
+        ),
         pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"),
         pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
     ],
......