".github/vscode:/vscode.git/clone" did not exist on "80589625c76b4c774f10bd05d70bb260df3fe978"
Unverified Commit 0994fb48 authored by Kshitij Lakhani's avatar Kshitij Lakhani Committed by GitHub
Browse files

Fix #1524 and other softmax mask functionality (#1681)



* Add test cases for full coverage in jax/test_layer.py
- causal and window size None
- causal and window size default (-1,1)
- no_mask and window size default (-1,1)
- no_mask and window size (2,2)
- padding and window size None
- padding_causal and window_size (2,2)
Signed-off-by: default avatarKshitij Janardan Lakhani <klakhani@nvidia.com>

* Correct the condition where padding_causal_mask was being mapped to scaled upper triangle
Signed-off-by: default avatarKshitij Janardan Lakhani <klakhani@nvidia.com>

* Fix Issue #1524
Signed-off-by: default avatarKshitij Janardan Lakhani <klakhani@nvidia.com>

* Add a runner and test cases for jax.flax.module.Softmax class for fwd pass only
Segregate runner classes for Softmax module and softmax primitives
Signed-off-by: default avatarKshitij Janardan Lakhani <klakhani@nvidia.com>

* Simplify logic when picking softmax primitives and softmax jax framework calls
Signed-off-by: default avatarKshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Simplify the logic for performing jax based softmax
Signed-off-by: default avatarKshitij Janardan Lakhani <klakhani@nvidia.com>

* Code clean up
Signed-off-by: default avatarKshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add support table for mask, SWA and Softmax type. Code linting
Signed-off-by: default avatarKshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Explicit SWA conditions in comments. Fix typo
Signed-off-by: default avatarKshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Resolve typo to remove None in SWA comments section
Signed-off-by: default avatarKshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarKshitij Janardan Lakhani <klakhani@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 86928e07
......@@ -215,12 +215,53 @@ ATTRS = [
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
},
# attrs22
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
_KEY_OF_WINDOW_SIZE: None,
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
},
# attrs23
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
},
# attrs24
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_MASK_TYPE: "no_mask",
},
# attrs25
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_MASK_TYPE: "no_mask",
_KEY_OF_WINDOW_SIZE: (2, 2),
},
# attrs26
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
_KEY_OF_WINDOW_SIZE: (2, 2),
},
# attrs27
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
_KEY_OF_WINDOW_SIZE: None,
},
# attrs28
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_WINDOW_SIZE: (2, 2),
},
]
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
......@@ -370,13 +411,13 @@ class EncoderRunner(BaseRunner):
data_rng = jax.random.PRNGKey(2024)
inputs = (jax.random.normal(data_rng, data_shape, dtype),)
padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
mask_shape = (batch, 1, seqlen, seqlen)
padded_mask = jnp.zeros(mask_shape, dtype=jnp.uint8)
causal_mask = jnp.triu(jnp.ones(mask_shape, dtype=jnp.uint8), k=1)
if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
mask = causal_mask
else:
mask = padded_mask
ref_masks = (1 - mask,)
test_masks = (None, mask) # The second arg of Transformer is encoded tokens.
......
......@@ -18,6 +18,7 @@ from utils import assert_allclose
from transformer_engine.jax.cpp_extensions import is_softmax_kernel_available
from transformer_engine.jax.softmax import SoftmaxType, softmax
from transformer_engine.jax.flax.module import Softmax
def catch_unsupported(method):
......@@ -94,7 +95,6 @@ class SoftmaxRunner:
case _:
raise ValueError(f"Unknown {self.softmax_type=}")
@catch_unsupported
def test_forward(self):
"""
Test transformer_engine.jax.softmax.softmax fwd rule
......@@ -104,7 +104,6 @@ class SoftmaxRunner:
reference_out = __class__.reference_softmax(self.logits, self.mask, self.scale_factor)
assert_allclose(primitive_out, reference_out, dtype=self.dtype)
@catch_unsupported
def test_backward(self):
"""
Test transformer_engine.jax.softmax.softmax bwd rule
......@@ -141,6 +140,50 @@ class SoftmaxRunner:
assert_allclose(primitive_grad_logits, reference_grad_logits, dtype=self.dtype)
class SoftmaxPrimitivesRunner(SoftmaxRunner):
    """
    Jax Softmax Primitives runner.

    Thin subclass of SoftmaxRunner that runs the same forward/backward
    checks but wraps them in the ``catch_unsupported`` decorator, which is
    defined earlier in this file — presumably it intercepts configurations
    the softmax primitives do not support (TODO confirm at its definition).
    """

    @catch_unsupported
    def test_forward(self):
        # Delegate to the shared forward check implemented in SoftmaxRunner.
        return super().test_forward()

    @catch_unsupported
    def test_backward(self):
        # Delegate to the shared backward check implemented in SoftmaxRunner.
        return super().test_backward()
class SoftmaxModuleRunner:
    """
    Jax Softmax Module runner.

    Drives the ``transformer_engine.jax.flax.module.Softmax`` module
    (forward pass only) and compares its output against the pure JAX
    reference softmax owned by a wrapped ``SoftmaxRunner``.
    """

    # Runner that owns the test inputs (logits/mask) and the reference softmax.
    module_runner: SoftmaxRunner
    # NOTE(review): bias is stored but never forwarded to the module below —
    # confirm whether bias-path coverage is intentionally deferred.
    bias: None

    def __init__(self, module_runner, bias):
        self.module_runner = module_runner
        self.bias = bias

    def test_forward(self):
        """
        Test transformer_engine.jax.flax.module.Softmax fwd rule
        """
        base = self.module_runner
        # Populate base.logits / base.mask before they are read below.
        base._setup_inputs()
        init_key = jax.random.PRNGKey(0)
        module = Softmax(
            scale_factor=base.scale_factor,
            softmax_type=base.softmax_type,
        )
        variables = module.init(init_key, base.logits, base.mask)
        actual = module.apply(variables, base.logits, base.mask)
        expected = base.reference_softmax(base.logits, base.mask, base.scale_factor)
        assert_allclose(actual, expected, dtype=base.dtype)
# Run softmax primitives test
@pytest.mark.parametrize(
"b, s_q, s_kv, h",
[
......@@ -165,7 +208,7 @@ class SoftmaxRunner:
pytest.param(jnp.float16, id="FP16"),
],
)
class TestSoftmax:
class TestSoftmaxPrimitives:
"""
Test transformer_engine.jax.softmax.softmax
"""
......@@ -175,7 +218,7 @@ class TestSoftmax:
"""
Test forward with parameterized configs
"""
runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
runner.test_forward()
@staticmethod
......@@ -183,5 +226,48 @@ class TestSoftmax:
"""
Test forward with parameterized configs
"""
runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
runner.test_backward()
# Run Softmax module test
# Shapes: batch, query seqlen, key/value seqlen, heads.
@pytest.mark.parametrize(
    "b, s_q, s_kv, h",
    [
        pytest.param(8, 16, 16, 16, id="8-16-16-16"),
        pytest.param(8, 512, 512, 16, id="8-512-512-16"),
        pytest.param(2, 8, 16384, 8, id="2-8-16384-8"),
        # triggers backup framework implementation due to (s_q % 4) != 0
        pytest.param(8, 511, 512, 16, id="8-511-512-16"),
    ],
)
@pytest.mark.parametrize("scale_factor", [0.125])
@pytest.mark.parametrize(
    "softmax_type",
    [
        pytest.param(SoftmaxType.SCALED, id="SCALED"),
        pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"),
        pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
    ],
)
@pytest.mark.parametrize(
    "dtype",
    [
        pytest.param(jnp.bfloat16, id="BF16"),
        pytest.param(jnp.float16, id="FP16"),
    ],
)
class TestSoftmaxModule:
    """
    Test transformer_engine.jax.flax.module.Softmax
    """

    @staticmethod
    def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
        """
        Test forward with parameterized configs
        """
        # Build the input/reference provider, then wrap it in the module
        # runner which exercises the flax Softmax module's forward pass.
        module_runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
        # Forward-only coverage without bias; the module runner stores but
        # does not forward the bias.
        bias = None
        runner = SoftmaxModuleRunner(module_runner, bias)
        runner.test_forward()
......@@ -31,6 +31,9 @@ __all__ = [
"scaled_upper_triang_masked_softmax_fwd",
"scaled_upper_triang_masked_softmax_bwd",
"is_softmax_kernel_available",
"jax_scaled_softmax",
"jax_scaled_masked_softmax",
"jax_scaled_upper_triang_masked_softmax",
]
......@@ -422,7 +425,7 @@ def scaled_softmax_bwd(
Return FP16/BF16 tensor
"""
if not ScaledSoftmaxBwdPrimitive.enabled():
_, vjp_func = jax.vjp(partial(_jax_scaled_softmax, scale_factor=scale_factor), logits)
_, vjp_func = jax.vjp(partial(jax_scaled_softmax, scale_factor=scale_factor), logits)
return vjp_func(dz)[0]
return ScaledSoftmaxBwdPrimitive.outer_primitive.bind(
......@@ -795,11 +798,17 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)
def _jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float):
def jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float):
    """
    Pure-JAX implementation of scaled softmax.

    Multiplies the logits by ``scale_factor`` and applies softmax over the
    last axis (the ``jax.nn.softmax`` default).
    """
    scaled_logits = logits * scale_factor
    return jax.nn.softmax(scaled_logits)
def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float):
def jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float):
"""
JAX based implementation of scaled and masked softmax
"""
if mask is not None:
logits += jax.lax.select(
mask > 0,
......@@ -809,7 +818,10 @@ def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_fac
return jax.nn.softmax(logits * scale_factor)
def _jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
def jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
"""
JAX based implementation of scaled and upper triangle masked softmax
"""
mask = 1 - jnp.tril(jnp.ones_like(logits))
logits += jax.lax.select(
mask > 0,
......@@ -825,7 +837,7 @@ def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
Return FP16/BF16 tensor
"""
if not ScaledSoftmaxFwdPrimitive.enabled():
return _jax_scaled_softmax(logits, scale_factor)
return jax_scaled_softmax(logits, scale_factor)
return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor)
......@@ -837,7 +849,7 @@ def scaled_masked_softmax_fwd(
Return FP16/BF16 tensor
"""
if not ScaledMaskedSoftmaxFwdPrimitive.enabled():
return _jax_scaled_masked_softmax(logits, mask, scale_factor)
return jax_scaled_masked_softmax(logits, mask, scale_factor)
return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
logits, mask, scale_factor=scale_factor
)
......@@ -856,7 +868,7 @@ def scaled_masked_softmax_bwd(
"""
if not ScaledMaskedSoftmaxBwdPrimitive.enabled():
_, vjp_func = jax.vjp(
partial(_jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask
partial(jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask
)
return vjp_func(dz)[0]
return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
......@@ -870,7 +882,7 @@ def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: fl
Return FP16/BF16 tensor
"""
if not ScaledUpperTriangMaskedSoftmaxFwdPrimitive.enabled():
return _jax_scaled_upper_triang_masked_softmax(logits, scale_factor)
return jax_scaled_upper_triang_masked_softmax(logits, scale_factor)
return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
logits, scale_factor=scale_factor
)
......@@ -885,7 +897,7 @@ def scaled_upper_triang_masked_softmax_bwd(
"""
if not ScaledUpperTriangMaskedSoftmaxBwdPrimitive.enabled():
_, vjp_func = jax.vjp(
partial(_jax_scaled_upper_triang_masked_softmax, scale_factor=scale_factor), logits
partial(jax_scaled_upper_triang_masked_softmax, scale_factor=scale_factor), logits
)
return vjp_func(dz)[0]
return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
......
......@@ -13,7 +13,6 @@ import jax.numpy as jnp
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import lax
from jax import nn as jax_nn
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name
......@@ -26,7 +25,12 @@ from ..layernorm_mlp import layernorm_mlp
from ..activation import activation
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import is_softmax_kernel_available
from ..cpp_extensions import (
is_softmax_kernel_available,
jax_scaled_softmax,
jax_scaled_masked_softmax,
jax_scaled_upper_triang_masked_softmax,
)
from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode
from ..sharding import get_non_contracting_logical_axes
......@@ -168,10 +172,10 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods
input_dtype = inputs.dtype
logits = inputs
if self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available(
# use primitives
if is_softmax_kernel_available(
self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype
):
if bias is not None:
logits = logits + bias.astype(input_dtype)
......@@ -180,31 +184,22 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods
mask_ = None
outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
# use default jax based implementation
else:
attention_bias = None
if mask is not None:
attention_bias = lax.select(
mask > 0,
jnp.full(mask.shape, -1e10),
jnp.full(mask.shape, 0.0),
)
attention_bias = attention_bias.astype(input_dtype)
if bias is not None:
attention_bias = _combine_biases(attention_bias, bias)
if attention_bias is not None:
logits = logits + attention_bias.astype(input_dtype)
logits = logits + bias.astype(input_dtype)
# For the case that self.softmax == SoftmaxType.SCALED_UPPER_TRIANG_MASKED
# and kernel is unavailable, then try on pure scaled softmax custom calls.
if is_softmax_kernel_available(
SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, input_dtype
):
outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED)
if self.softmax_type is SoftmaxType.SCALED:
outputs = jax_scaled_softmax(logits, self.scale_factor)
elif self.softmax_type is SoftmaxType.SCALED_MASKED:
outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor)
elif self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
outputs = jax_scaled_upper_triang_masked_softmax(logits, self.scale_factor)
else:
outputs = jax_nn.softmax(logits * self.scale_factor)
raise ValueError(
f"Unsupported softmax type: {self.softmax_type}. softmax_type must be [SCALED,"
" SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
)
assert input_dtype == outputs.dtype
return outputs
......
......@@ -220,11 +220,11 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
if mask is not None:
mask = apply_swa_mask(mask)
# Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
if mask is not None:
return SoftmaxType.SCALED_MASKED, mask
if attn_mask_type is AttnMaskType.CAUSAL_MASK:
return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]:
if mask is not None:
return SoftmaxType.SCALED_MASKED, mask
if attn_mask_type is AttnMaskType.NO_MASK:
return SoftmaxType.SCALED, mask
raise ValueError(
f"Unsupported {attn_mask_type=}, supported attn_mask_type="
......@@ -447,6 +447,14 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
.. note:: THD format only supports 'padding' or 'causal_padding' mask type.
attn_mask_type mask/sequence_descriptor SWA softmax type
--------------------------------------------------------------------------------------------
no_mask None None SCALED
causal None None SCALED_UPPER_TRIANG_MASKED
causal None Yes SCALED_MASKED
padding Required Yes/No SCALED_MASKED
padding_causal Required Yes/No SCALED_MASKED
attn_bias_type: Optional[str], default = None
Type of the attention bias passed in the attention.
Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment