Unverified Commit 5e8a9a96 authored by Kshitij Lakhani's avatar Kshitij Lakhani Committed by GitHub
Browse files

[JAX] Fix: Skip determinism tests for bprop for all sm >=100 (#2315)



* Fix: Skip determinism tests for bprop for all sm >=100
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Add username to TODO
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Assert in fused attn bwd pass for sm100+
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent f0295f9d
......@@ -378,14 +378,14 @@ class FusedAttnRunner:
pytest.skip(
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
)
# TODO(KshitijLakhani): Set the upper limit for skipping this test when cuDNN adds support
if (
get_device_compute_capability(0) == 100
get_device_compute_capability(0) >= 100
and self.dropout_prob == 0.1
and self.attn_bias_type is not AttnBiasType.NO_BIAS
):
pytest.skip(
"For sm100, bprop kernel support for dropout + determinism (bias) is not supported"
"For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
)
# Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
# are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
......
......@@ -2739,10 +2739,13 @@ def fused_attn_bwd(
assert bias is None
bias = jnp.zeros(0, dtype=qkv[0].dtype)
if 100 in get_all_device_compute_capability():
# TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on
# sm100+
compute_capabilities = get_all_device_compute_capability()
if any(x >= 100 for x in compute_capabilities):
assert not (
attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0
), "For sm100, bprop kernel support for dropout + determinism (bias) is not supported"
), "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
fused_config = _FusedAttnConfig(
attn_bias_type=attn_bias_type,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment