Unverified Commit 47902e96 authored by jberchtold-nvidia, committed by GitHub

[JAX] Remove unused TE DPA module dtype which fixes cuDNN backend detection to properly use input dtypes (#2485)

* Remove unused TE DPA module dtype which fixes cuDNN backend detection to properly use input dtypes
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fall back to a deprecation warning when the specified dtype matches the input dtype
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Adjust encoder test tolerances slightly due to the change in attention backend
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent d46d5db4
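
In short, after this change the module's compute dtype is taken from the inputs. A minimal usage sketch, assuming the import path `transformer_engine.jax.flax.DotProductAttention`; the shapes, seed, and constructor arguments below are illustrative, so check the TE JAX API reference for the exact required fields:

```python
# Hypothetical example; shapes, seed, and constructor arguments are illustrative.
import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import DotProductAttention  # assumed import path

b, s, h, d = 2, 128, 8, 64
q = jnp.zeros((b, s, h, d), dtype=jnp.bfloat16)
k = jnp.zeros((b, s, h, d), dtype=jnp.bfloat16)
v = jnp.zeros((b, s, h, d), dtype=jnp.bfloat16)

# No `dtype=...` argument: the module has no parameters of its own, so the
# attention math and cuDNN backend detection follow the q/k/v dtype (bf16 here).
attn = DotProductAttention(head_dim=d, attn_mask_type="causal")
variables = attn.init(jax.random.PRNGKey(0), q, k, v)
out = attn.apply(variables, q, k, v)
```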
@@ -535,7 +535,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.36 and actual[1] > 0.84
+        assert actual[0] < 0.361 and actual[1] > 0.84
 
     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8_with_sp(self):
@@ -569,7 +569,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.36 and actual[1] > 0.84
+        assert actual[0] < 0.361 and actual[1] > 0.84
 
     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8_with_sp_shardy(self):
@@ -598,6 +598,11 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     ``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
     (``'zero sink'`` and ``'learnable sink'``).
 
+    Optimization parameters
+    -----------------------
+    dtype (deprecated): jax.numpy.dtype, default = None
+        This argument is deprecated and will be removed in a future release. DotProductAttention
+        will use the dtype of the inputs instead, as this module does not have any parameters.
     """
 
     head_dim: int
@@ -606,6 +611,7 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     attention_dropout: float = 0.0
     attn_mask_type: AttnMaskType = "causal"
     attn_bias_type: AttnBiasType = None
+    dtype: Optional[DType] = None  # Deprecated
     dropout_rng_name: str = "dropout"
     float32_logits: bool = False
     qkv_layout: str = "bshd_bshd_bshd"
@@ -637,14 +643,14 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
         elif qkv_layout.is_kvpacked():
             assert (
                 key.dtype == query.dtype
-            ), f"Expected kv dtype={key.dtype} to match query dtype={query.dtype}."
+            ), f"Expected kv {key.dtype=} to match query {query.dtype=}."
         elif qkv_layout.is_separate():
             assert (
                 key.dtype == query.dtype
-            ), f"Expected key dtype={key.dtype} to match query dtype={query.dtype}."
+            ), f"Expected key {key.dtype=} to match query {query.dtype=}."
             assert (
                 value.dtype == query.dtype
-            ), f"Expected value dtype={value.dtype} to match query dtype={query.dtype}."
+            ), f"Expected value {value.dtype=} to match query {query.dtype=}."
         else:
             raise ValueError(f"Unsupported {qkv_layout=}.")
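
A side note on the message rewrite in this hunk: the f-string `=` specifier (Python 3.8+) prints both the expression and its value, so the assertion text no longer repeats the variable names by hand. A tiny standalone illustration (the arrays here are made up, not from the diff):

```python
import numpy as np

key = np.zeros(4, dtype=np.float16)    # illustrative inputs
query = np.zeros(4, dtype=np.float32)
print(f"Expected key {key.dtype=} to match query {query.dtype=}.")
# -> Expected key key.dtype=dtype('float16') to match query query.dtype=dtype('float32').
```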
@@ -713,6 +719,22 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
             bias = bias.astype(input_dtype)
 
         self._assert_dtypes(query, key, value, qkv_layout)
 
+        if self.dtype is not None:
+            if self.dtype == input_dtype:
+                warnings.warn(
+                    "The dtype argument is deprecated and will be removed in a future release."
+                    " DotProductAttention will use the dtype of the inputs instead as this module"
+                    f" does not have any parameters. Module dtype specified {self.dtype=} matches"
+                    " dtype of inputs so behavior is unchanged. Please remove the dtype argument"
+                    " within the next few releases."
+                )
+            else:
+                raise ValueError(
+                    "The DotProductAttention module dtype is deprecated and will be removed in a"
+                    " future release. DotProductAttention will use the dtype of the inputs instead"
+                    " as this module does not have any parameters. Module dtype specified"
+                    f" {self.dtype=} does not match dtype of inputs {input_dtype=}."
+                )
         # Use fused attn (if kernel check below passes) by default
         enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "1"))
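
The hunk above follows a common deprecation policy: keep accepting the old argument, warn when honoring it is a no-op, and fail loudly when honoring it would silently change behavior. A minimal sketch of the same policy outside the module (the helper name is hypothetical, not TE API):

```python
import warnings

import jax.numpy as jnp


def resolve_dtype(deprecated_dtype, input_dtype):
    """Warn if the deprecated value is redundant; raise if it conflicts."""
    if deprecated_dtype is not None:
        if deprecated_dtype == input_dtype:
            # Redundant: behavior is unchanged, so a warning suffices.
            warnings.warn("dtype is deprecated; the input dtype is used instead.")
        else:
            # Conflicting: silently ignoring it would change numerics, so fail.
            raise ValueError(
                f"Deprecated dtype {deprecated_dtype} conflicts with input dtype"
                f" {input_dtype}; remove the dtype argument."
            )
    return input_dtype  # the inputs are the single source of truth


resolve_dtype(None, jnp.bfloat16)          # no deprecated value: returns bfloat16
resolve_dtype(jnp.bfloat16, jnp.bfloat16)  # redundant value: warns, still bfloat16
```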