"...git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "1b0bd0fe8a4a9de749b9d4618758ff20c8658d86"
Unverified Commit 0994fb48 authored by Kshitij Lakhani, committed by GitHub

Fix #1524 and other softmax mask functionality (#1681)



* Add test cases for full coverage in jax/test_layer.py
- causal and window size None
- causal and window size default (-1,-1)
- no_mask and window size default (-1,-1)
- no_mask and window size (2,2)
- padding and window size None
- padding_causal and window size (2,2)
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Correct the condition where padding_causal_mask was being mapped to scaled upper triangle
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Fix Issue #1524
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Add a runner and test cases for the jax.flax.module.Softmax class (fwd pass only)
Segregate runner classes for the Softmax module and the softmax primitives
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Simplify the logic for choosing between softmax primitives and JAX framework softmax calls
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Simplify the logic for performing JAX-based softmax
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Code clean up
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add a support table for mask, SWA, and softmax type. Code linting
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Make SWA conditions explicit in comments. Fix typo
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix typo: remove None from the SWA comments section
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 86928e07
@@ -215,12 +215,53 @@ ATTRS = [
         _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
     },
     # attrs22
+    {
+        _KEY_OF_TRANSPOSE_BS: False,
+        _KEY_OF_RELATIVE_EMBEDDING: False,
+        _KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
+        _KEY_OF_WINDOW_SIZE: None,
+        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
+    },
+    # attrs23
+    {
+        _KEY_OF_TRANSPOSE_BS: False,
+        _KEY_OF_RELATIVE_EMBEDDING: False,
+        _KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
+        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
+    },
+    # attrs24
+    {
+        _KEY_OF_TRANSPOSE_BS: False,
+        _KEY_OF_RELATIVE_EMBEDDING: False,
+        _KEY_OF_SELF_ATTN_MASK_TYPE: "no_mask",
+    },
+    # attrs25
+    {
+        _KEY_OF_TRANSPOSE_BS: False,
+        _KEY_OF_RELATIVE_EMBEDDING: False,
+        _KEY_OF_SELF_ATTN_MASK_TYPE: "no_mask",
+        _KEY_OF_WINDOW_SIZE: (2, 2),
+    },
+    # attrs26
     {
         _KEY_OF_TRANSPOSE_BS: False,
         _KEY_OF_RELATIVE_EMBEDDING: False,
         _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
         _KEY_OF_WINDOW_SIZE: (2, 2),
     },
+    # attrs27
+    {
+        _KEY_OF_TRANSPOSE_BS: False,
+        _KEY_OF_RELATIVE_EMBEDDING: False,
+        _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
+        _KEY_OF_WINDOW_SIZE: None,
+    },
+    # attrs28
+    {
+        _KEY_OF_TRANSPOSE_BS: False,
+        _KEY_OF_RELATIVE_EMBEDDING: False,
+        _KEY_OF_WINDOW_SIZE: (2, 2),
+    },
 ]
 ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
@@ -370,13 +411,13 @@ class EncoderRunner(BaseRunner):
         data_rng = jax.random.PRNGKey(2024)
         inputs = (jax.random.normal(data_rng, data_shape, dtype),)
-        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
-        causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
+        mask_shape = (batch, 1, seqlen, seqlen)
+        padded_mask = jnp.zeros(mask_shape, dtype=jnp.uint8)
+        causal_mask = jnp.triu(jnp.ones(mask_shape, dtype=jnp.uint8), k=1)
         if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
             mask = causal_mask
         else:
             mask = padded_mask
         ref_masks = (1 - mask,)
         test_masks = (None, mask)  # The second arg of Transformer is encoded tokens.
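For orientation, a minimal standalone sketch of the reference-mask construction exercised by the new test combinations (shapes chosen arbitrarily for illustration, not taken from the test configs): the padded mask is all zeros, the causal mask marks positions strictly above the diagonal, and the reference path consumes 1 - mask.

import jax.numpy as jnp

batch, seqlen = 2, 4
mask_shape = (batch, 1, seqlen, seqlen)

# All-zero "padding" mask: no position is masked out.
padded_mask = jnp.zeros(mask_shape, dtype=jnp.uint8)
# Causal mask: ones strictly above the diagonal mark disallowed positions.
causal_mask = jnp.triu(jnp.ones(mask_shape, dtype=jnp.uint8), k=1)

# The reference layer uses (1 - mask), i.e. 1 where attention is allowed.
ref_causal = (1 - causal_mask)[0, 0]
print(ref_causal)  # lower-triangular matrix of ones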
@@ -18,6 +18,7 @@ from utils import assert_allclose
 from transformer_engine.jax.cpp_extensions import is_softmax_kernel_available
 from transformer_engine.jax.softmax import SoftmaxType, softmax
+from transformer_engine.jax.flax.module import Softmax


 def catch_unsupported(method):

@@ -94,7 +95,6 @@ class SoftmaxRunner:
             case _:
                 raise ValueError(f"Unknown {self.softmax_type=}")

-    @catch_unsupported
     def test_forward(self):
         """
         Test transformer_engine.jax.softmax.softmax fwd rule

@@ -104,7 +104,6 @@ class SoftmaxRunner:
         reference_out = __class__.reference_softmax(self.logits, self.mask, self.scale_factor)
         assert_allclose(primitive_out, reference_out, dtype=self.dtype)

-    @catch_unsupported
     def test_backward(self):
         """
         Test transformer_engine.jax.softmax.softmax bwd rule

@@ -141,6 +140,50 @@ class SoftmaxRunner:
         assert_allclose(primitive_grad_logits, reference_grad_logits, dtype=self.dtype)


+class SoftmaxPrimitivesRunner(SoftmaxRunner):
+    """
+    Jax Softmax Primitives runner
+    """
+
+    @catch_unsupported
+    def test_forward(self):
+        return super().test_forward()
+
+    @catch_unsupported
+    def test_backward(self):
+        return super().test_backward()
+
+
+class SoftmaxModuleRunner:
+    """
+    Jax Softmax Module runner
+    """
+
+    module_runner: SoftmaxRunner
+    bias: None
+
+    def __init__(self, module_runner, bias):
+        self.module_runner = module_runner
+        self.bias = bias
+
+    def test_forward(self):
+        """
+        Test transformer_engine.jax.flax.module.Softmax fwd rule
+        """
+        runner = self.module_runner
+        runner._setup_inputs()
+        rng = jax.random.PRNGKey(0)
+        softmax_module = Softmax(
+            scale_factor=runner.scale_factor,
+            softmax_type=runner.softmax_type,
+        )
+        softmax_vars = softmax_module.init(rng, runner.logits, runner.mask)
+        module_out = softmax_module.apply(softmax_vars, runner.logits, runner.mask)
+        reference_out = runner.reference_softmax(runner.logits, runner.mask, runner.scale_factor)
+        assert_allclose(module_out, reference_out, dtype=runner.dtype)
+
+
+# Run softmax primitives test
 @pytest.mark.parametrize(
     "b, s_q, s_kv, h",
     [

@@ -165,7 +208,7 @@ class SoftmaxRunner:
         pytest.param(jnp.float16, id="FP16"),
     ],
 )
-class TestSoftmax:
+class TestSoftmaxPrimitives:
     """
     Test transformer_engine.jax.softmax.softmax
     """

@@ -175,7 +218,7 @@ class TestSoftmax:
         """
         Test forward with parameterized configs
         """
-        runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
+        runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
         runner.test_forward()

     @staticmethod

@@ -183,5 +226,48 @@ class TestSoftmax:
         """
         Test forward with parameterized configs
         """
-        runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
+        runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
         runner.test_backward()
+
+
+# Run Softmax module test
+@pytest.mark.parametrize(
+    "b, s_q, s_kv, h",
+    [
+        pytest.param(8, 16, 16, 16, id="8-16-16-16"),
+        pytest.param(8, 512, 512, 16, id="8-512-512-16"),
+        pytest.param(2, 8, 16384, 8, id="2-8-16384-8"),
+        # triggers backup framework implementation due to (s_q % 4) != 0
+        pytest.param(8, 511, 512, 16, id="8-511-512-16"),
+    ],
+)
+@pytest.mark.parametrize("scale_factor", [0.125])
+@pytest.mark.parametrize(
+    "softmax_type",
+    [
+        pytest.param(SoftmaxType.SCALED, id="SCALED"),
+        pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"),
+        pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
+    ],
+)
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        pytest.param(jnp.bfloat16, id="BF16"),
+        pytest.param(jnp.float16, id="FP16"),
+    ],
+)
+class TestSoftmaxModule:
+    """
+    Test transformer_engine.jax.flax.module.Softmax
+    """
+
+    @staticmethod
+    def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
+        """
+        Test forward with parameterized configs
+        """
+        module_runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
+        bias = None
+        runner = SoftmaxModuleRunner(module_runner, bias)
+        runner.test_forward()
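As a rough usage sketch of the Flax module covered by the new TestSoftmaxModule (shapes, dtype, and the 0.125 scale factor are arbitrary; it assumes an installation in which transformer_engine.jax.flax.module.Softmax is importable, as in the test above):

import jax
import jax.numpy as jnp
from transformer_engine.jax.softmax import SoftmaxType
from transformer_engine.jax.flax.module import Softmax

b, h, s_q, s_kv = 2, 4, 16, 16
rng = jax.random.PRNGKey(0)
logits = jax.random.normal(rng, (b, h, s_q, s_kv), dtype=jnp.float16)
mask = jnp.zeros((b, 1, s_q, s_kv), dtype=jnp.uint8)  # 0 = attend, >0 = masked out

# Initialize and apply like any other Flax module, mirroring the test's init/apply calls.
softmax_module = Softmax(scale_factor=0.125, softmax_type=SoftmaxType.SCALED_MASKED)
variables = softmax_module.init(rng, logits, mask)
probs = softmax_module.apply(variables, logits, mask)
print(probs.shape, probs.dtype)  # same shape and dtype as the logits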
@@ -31,6 +31,9 @@ __all__ = [
     "scaled_upper_triang_masked_softmax_fwd",
     "scaled_upper_triang_masked_softmax_bwd",
     "is_softmax_kernel_available",
+    "jax_scaled_softmax",
+    "jax_scaled_masked_softmax",
+    "jax_scaled_upper_triang_masked_softmax",
 ]

@@ -422,7 +425,7 @@ def scaled_softmax_bwd(
     Return FP16/BF16 tensor
     """
     if not ScaledSoftmaxBwdPrimitive.enabled():
-        _, vjp_func = jax.vjp(partial(_jax_scaled_softmax, scale_factor=scale_factor), logits)
+        _, vjp_func = jax.vjp(partial(jax_scaled_softmax, scale_factor=scale_factor), logits)
         return vjp_func(dz)[0]
     return ScaledSoftmaxBwdPrimitive.outer_primitive.bind(

@@ -795,11 +798,17 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
 register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)


-def _jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float):
+def jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float):
+    """
+    JAX based implementation of scaled softmax
+    """
     return jax.nn.softmax(scale_factor * logits)


-def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float):
+def jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float):
+    """
+    JAX based implementation of scaled and masked softmax
+    """
     if mask is not None:
         logits += jax.lax.select(
             mask > 0,

@@ -809,7 +818,10 @@ def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_fac
     return jax.nn.softmax(logits * scale_factor)


-def _jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
+def jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
+    """
+    JAX based implementation of scaled and upper triangle masked softmax
+    """
     mask = 1 - jnp.tril(jnp.ones_like(logits))
     logits += jax.lax.select(
         mask > 0,

@@ -825,7 +837,7 @@ def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
     Return FP16/BF16 tensor
     """
     if not ScaledSoftmaxFwdPrimitive.enabled():
-        return _jax_scaled_softmax(logits, scale_factor)
+        return jax_scaled_softmax(logits, scale_factor)
     return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor)

@@ -837,7 +849,7 @@ def scaled_masked_softmax_fwd(
     Return FP16/BF16 tensor
     """
     if not ScaledMaskedSoftmaxFwdPrimitive.enabled():
-        return _jax_scaled_masked_softmax(logits, mask, scale_factor)
+        return jax_scaled_masked_softmax(logits, mask, scale_factor)
     return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
         logits, mask, scale_factor=scale_factor
     )

@@ -856,7 +868,7 @@ def scaled_masked_softmax_bwd(
     """
     if not ScaledMaskedSoftmaxBwdPrimitive.enabled():
         _, vjp_func = jax.vjp(
-            partial(_jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask
+            partial(jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask
         )
         return vjp_func(dz)[0]
     return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind(

@@ -870,7 +882,7 @@ def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: fl
     Return FP16/BF16 tensor
     """
     if not ScaledUpperTriangMaskedSoftmaxFwdPrimitive.enabled():
-        return _jax_scaled_upper_triang_masked_softmax(logits, scale_factor)
+        return jax_scaled_upper_triang_masked_softmax(logits, scale_factor)
     return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
         logits, scale_factor=scale_factor
     )

@@ -885,7 +897,7 @@ def scaled_upper_triang_masked_softmax_bwd(
     """
     if not ScaledUpperTriangMaskedSoftmaxBwdPrimitive.enabled():
         _, vjp_func = jax.vjp(
-            partial(_jax_scaled_upper_triang_masked_softmax, scale_factor=scale_factor), logits
+            partial(jax_scaled_upper_triang_masked_softmax, scale_factor=scale_factor), logits
         )
         return vjp_func(dz)[0]
     return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
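The three helpers promoted to the public API above are pure-JAX fallbacks used when the fused kernels are not enabled. A self-contained, illustrative mirror of the upper-triangular variant shown in this diff (the local function name is a stand-in, not the library symbol):

import jax
import jax.numpy as jnp

def upper_triang_masked_softmax_reference(logits, scale_factor):
    # Ones strictly above the diagonal mark disallowed positions.
    mask = 1 - jnp.tril(jnp.ones_like(logits))
    logits = logits + jax.lax.select(
        mask > 0,
        jnp.full(mask.shape, -1e10),
        jnp.full(mask.shape, 0.0),
    )
    return jax.nn.softmax(logits * scale_factor)

x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 8, 8), dtype=jnp.float32)
y = upper_triang_masked_softmax_reference(x, 0.125)
# Rows sum to 1 and entries above the diagonal receive (numerically) zero weight.
print(bool(jnp.allclose(y.sum(-1), 1.0, atol=1e-5)), float(y[0, 0, 0, -1]))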
@@ -13,7 +13,6 @@ import jax.numpy as jnp
 from flax import linen as nn
 from flax.linen import partitioning as nn_partitioning
 from jax import lax
-from jax import nn as jax_nn
 from jax import random as jax_random
 from jax.ad_checkpoint import checkpoint_name

@@ -26,7 +25,12 @@ from ..layernorm_mlp import layernorm_mlp
 from ..activation import activation
 from ..softmax import softmax, SoftmaxType
 from ..sharding import with_sharding_constraint_by_logical_axes
-from ..cpp_extensions import is_softmax_kernel_available
+from ..cpp_extensions import (
+    is_softmax_kernel_available,
+    jax_scaled_softmax,
+    jax_scaled_masked_softmax,
+    jax_scaled_upper_triang_masked_softmax,
+)
 from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode
 from ..sharding import get_non_contracting_logical_axes

@@ -168,10 +172,10 @@ class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
         input_dtype = inputs.dtype
         logits = inputs

-        if self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available(
+        # use primitives
+        if is_softmax_kernel_available(
             self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype
         ):
             if bias is not None:
                 logits = logits + bias.astype(input_dtype)

@@ -180,31 +184,22 @@ class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
                 mask_ = None

             outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
+        # use default jax based implementation
         else:
-            attention_bias = None
-            if mask is not None:
-                attention_bias = lax.select(
-                    mask > 0,
-                    jnp.full(mask.shape, -1e10),
-                    jnp.full(mask.shape, 0.0),
-                )
-                attention_bias = attention_bias.astype(input_dtype)
-
             if bias is not None:
-                attention_bias = _combine_biases(attention_bias, bias)
-
-            if attention_bias is not None:
-                logits = logits + attention_bias.astype(input_dtype)
+                logits = logits + bias.astype(input_dtype)

-            # For the case that self.softmax == SoftmaxType.SCALED_UPPER_TRIANG_MASKED
-            # and kernel is unavailable, then try on pure scaled softmax custom calls.
-            if is_softmax_kernel_available(
-                SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, input_dtype
-            ):
-                outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED)
+            if self.softmax_type is SoftmaxType.SCALED:
+                outputs = jax_scaled_softmax(logits, self.scale_factor)
+            elif self.softmax_type is SoftmaxType.SCALED_MASKED:
+                outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor)
+            elif self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
+                outputs = jax_scaled_upper_triang_masked_softmax(logits, self.scale_factor)
             else:
-                outputs = jax_nn.softmax(logits * self.scale_factor)
+                raise ValueError(
+                    f"Unsupported softmax type: {self.softmax_type}. softmax_type must be [SCALED,"
+                    " SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
+                )

         assert input_dtype == outputs.dtype
         return outputs
@@ -220,11 +220,11 @@ class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-
         if mask is not None:
             mask = apply_swa_mask(mask)

         # Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
-        if mask is not None:
-            return SoftmaxType.SCALED_MASKED, mask
-        if attn_mask_type is AttnMaskType.CAUSAL_MASK:
-            return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask
-        if attn_mask_type is AttnMaskType.NO_MASK:
-            return SoftmaxType.SCALED, mask
+        if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
+            return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask
+        if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]:
+            if mask is not None:
+                return SoftmaxType.SCALED_MASKED, mask
+            return SoftmaxType.SCALED, mask
         raise ValueError(
             f"Unsupported {attn_mask_type=}, supported attn_mask_type="
@@ -447,6 +447,14 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
         .. note:: THD format only supports 'padding' or 'causal_padding' mask type.
+
+        attn_mask_type  mask/sequence_descriptor  SWA     softmax type
+        ----------------------------------------------------------------------------
+        no_mask         None                      None    SCALED
+        causal          None                      None    SCALED_UPPER_TRIANG_MASKED
+        causal          None                      Yes     SCALED_MASKED
+        padding         Required                  Yes/No  SCALED_MASKED
+        padding_causal  Required                  Yes/No  SCALED_MASKED
     attn_bias_type: Optional[str], default = None
         Type of the attention bias passed in the attention.
         Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
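A rough sketch of the functional entry point that the Softmax module dispatches to when a fused kernel is available, here with the upper-triangular (causal) type and no explicit mask, as the fused path does internally. Shapes and the scale factor are arbitrary, and this assumes the requested type is supported for the given shape/dtype on the current device; unsupported combinations raise, which is what the catch_unsupported decorator in the tests above guards against.

import jax
import jax.numpy as jnp
from transformer_engine.jax.softmax import SoftmaxType, softmax

b, h, s = 2, 4, 16
logits = jax.random.normal(jax.random.PRNGKey(0), (b, h, s, s), dtype=jnp.float16)

# Same call shape as Softmax.__call__ uses internally:
# softmax(logits, mask, scale_factor, softmax_type)
probs = softmax(logits, None, 0.125, SoftmaxType.SCALED_UPPER_TRIANG_MASKED)
print(probs.shape, probs.dtype)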