Unverified commit f378eaf2 authored by Kshitij Lakhani, committed by GitHub

[JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135)



* Fix failing fused attn tests for dropout=0.1 and bias on Blackwell
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix the skip message
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Assert in fused attn bwd pass for sm100
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

Add check for sm100
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add support to get all devices in the process for JAX
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Code clean up
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Make get_all_device_compute_capability more Pythonic, thereby avoiding unnecessary type conversion
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Represent attn bias using enum instead of string
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
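
For illustration only (not part of this commit): a minimal sketch of the enum-based bias check described in the bullet above. It assumes an AttnBiasType enum with a NO_BIAS member, as referenced in the diff below; the other members and the helper function shown here are hypothetical.

from enum import Enum


class AttnBiasType(Enum):
    # Hypothetical stand-in for the AttnBiasType enum referenced in the diff below.
    NO_BIAS = 0
    PRE_SCALE_BIAS = 1
    POST_SCALE_BIAS = 2


def needs_sm100_skip(attn_bias_type: AttnBiasType, dropout_prob: float) -> bool:
    # Enum identity comparison instead of matching on a raw string such as "no_bias":
    # a typo in a string comparison fails silently, while a misspelled enum member raises.
    return attn_bias_type is not AttnBiasType.NO_BIAS and dropout_prob != 0.0

The same `is not AttnBiasType.NO_BIAS` comparison appears in the test skip added in the diff below.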

---------
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 3b4366be
@@ -41,6 +41,7 @@ from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine_jax import (
    NVTE_Fused_Attn_Backend,
    get_cudnn_version,
    get_device_compute_capability,
)

from distributed_test_base import assert_equal_collectives
@@ -348,6 +349,14 @@ class FusedAttnRunner:
                "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
            )

        if (
            get_device_compute_capability(0) == 100
            and self.dropout_prob == 0.1
            and self.attn_bias_type is not AttnBiasType.NO_BIAS
        ):
            pytest.skip(
                "For sm100, bprop kernel support for dropout + determinism (bias) is not supported"
            )
        # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
        # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
        if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate():
...
@@ -34,6 +34,7 @@ from .misc import (
    te_dtype_to_jax_dtype,
    get_padded_spec,
    get_cudnn_version,
    get_all_device_compute_capability,
)
from ..sharding import (
    global_mesh_resource,
@@ -2745,6 +2746,11 @@ def fused_attn_bwd(
        assert bias is None
        bias = jnp.zeros(0, dtype=qkv[0].dtype)

    if 100 in get_all_device_compute_capability():
        assert not (
            attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0
        ), "For sm100, bprop kernel support for dropout + determinism (bias) is not supported"

    fused_config = _FusedAttnConfig(
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
...
@@ -193,6 +193,16 @@ def get_min_device_compute_capability():
    )


def get_all_device_compute_capability():
    """
    Returns a list of compute capability of all local devices.
    """
    return tuple(
        transformer_engine_jax.get_device_compute_capability(local_gpu_id)
        for local_gpu_id in range(len(jax.local_devices()))
    )


def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quantizer=None):
    """
    Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply a workaround to
...
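
A rough usage sketch (not part of the commit) of the new get_all_device_compute_capability helper, mirroring the membership test added to fused_attn_bwd above. The standalone-script layout and the print call are assumptions for illustration; the helper body is copied from the diff above.

import jax
import transformer_engine_jax  # compiled extension, as imported in the diff above


def get_all_device_compute_capability():
    # Same logic as the helper added above: compute capability of every local device.
    return tuple(
        transformer_engine_jax.get_device_compute_capability(local_gpu_id)
        for local_gpu_id in range(len(jax.local_devices()))
    )


# Guard sm100-specific behaviour: bias + dropout is rejected in the fused attn
# bwd pass when any local device reports compute capability 100 (Blackwell).
if 100 in get_all_device_compute_capability():
    print("sm100 detected: bias + dropout is unsupported in the fused attn bwd kernel")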