Unverified Commit 15cefbc5 authored by Paweł Gadziński, committed by GitHub

[JAX] Add support for sink attention in JAX (#2225)



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* removed packed versions
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* jax
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix:
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* softmax_fusion -> softmax_fusion_type
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 7e593c3b
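Note on the feature, before the diff: "sink attention" lets each softmax row give up probability mass to one virtual extra logit per head, either fixed at zero (off-by-one softmax) or learnable. A minimal standalone JAX sketch of the semantics this PR wires through the stack (illustrative only, not the TE implementation; names mirror the reference code in the diff below):

import jax
import jax.numpy as jnp

def sink_softmax(logits, sink_logit=None):
    """sink_logit=None  -> vanilla softmax
    sink_logit=0.0      -> off-by-one: exp(x_i) / (1 + sum_j exp(x_j))
    sink_logit=alpha    -> learnable:  exp(x_i) / (exp(alpha) + sum_j exp(x_j))"""
    if sink_logit is None:
        return jax.nn.softmax(logits, axis=-1)
    # Append the sink as one extra logit column, normalize, then drop it
    extra = jnp.full(logits.shape[:-1] + (1,), sink_logit, dtype=logits.dtype)
    probs = jax.nn.softmax(jnp.concatenate([logits, extra], axis=-1), axis=-1)
    return probs[..., :-1]  # each row now sums to less than 1

x = jnp.array([[1.0, 2.0, 3.0]])
expected = jnp.exp(x) / (1.0 + jnp.exp(x).sum(axis=-1, keepdims=True))
assert jnp.allclose(sink_softmax(x, 0.0), expected)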
@@ -18,6 +18,7 @@ from transformer_engine.jax.attention import (
     is_fused_attn_kernel_available,
     AttnBiasType,
     AttnMaskType,
+    AttnSoftmaxType,
     QKVLayout,
     QKVFormat,
     reorder_causal_load_balancing,
@@ -66,6 +67,7 @@ class TestDistributedSelfAttn:
         bias_shape,
         attn_mask_type,
         dtype,
+        softmax_type,
         use_shardy,
     ):
         jax.config.update("jax_use_shardy_partitioner", use_shardy)
@@ -80,6 +82,7 @@ class TestDistributedSelfAttn:
             QKVLayout.BS3HD,
             attn_bias_type,
             attn_mask_type,
+            softmax_type,
             dropout_prob,
             num_head,
             num_head,
@@ -109,6 +112,7 @@ class TestDistributedSelfAttn:
             hidden,
             attn_bias_type,
             attn_mask_type,
+            softmax_type,
             dropout_prob,
             dtype,
             is_training,
@@ -142,6 +146,14 @@ class TestDistributedSelfAttn:
         ],
     )
     @pytest.mark.parametrize("dtype", DTYPES)
+    @pytest.mark.parametrize(
+        "softmax_type",
+        [
+            pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
+            pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
+            pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
+        ],
+    )
     def test_self_attn(
         self,
         device_count,
@@ -153,6 +165,7 @@ class TestDistributedSelfAttn:
         bias_shape,
         attn_mask_type,
         dtype,
+        softmax_type,
     ):
         self.impl_test_self_attn(
             device_count,
@@ -164,6 +177,7 @@ class TestDistributedSelfAttn:
             bias_shape,
             attn_mask_type,
             dtype,
+            softmax_type,
             use_shardy=False,
         )
@@ -175,8 +189,23 @@ class TestDistributedSelfAttn:
             pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
         ],
     )
+    @pytest.mark.parametrize(
+        "softmax_type",
+        [
+            pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
+            pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
+            pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
+        ],
+    )
     def test_self_attn_shardy(
-        self, device_count, mesh_shape, mesh_axes, mesh_resource, attn_bias_type, bias_shape
+        self,
+        device_count,
+        mesh_shape,
+        mesh_axes,
+        mesh_resource,
+        attn_bias_type,
+        bias_shape,
+        softmax_type,
     ):
         data_shape = (32, 512, 12, 64)
         self.impl_test_self_attn(
@@ -189,6 +218,7 @@ class TestDistributedSelfAttn:
             bias_shape,
             AttnMaskType.PADDING_MASK,
             jnp.bfloat16,
+            softmax_type,
             use_shardy=True,
         )
@@ -213,8 +243,24 @@ class TestDistributedCrossAttn:
         "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
     )
     @pytest.mark.parametrize("dtype", DTYPES)
+    @pytest.mark.parametrize(
+        "softmax_type",
+        [
+            pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
+            pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
+            pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
+        ],
+    )
     def test_cross_attn(
-        self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype
+        self,
+        device_count,
+        mesh_shape,
+        mesh_axes,
+        mesh_resource,
+        data_shape,
+        attn_mask_type,
+        dtype,
+        softmax_type,
     ):
         attn_bias_type = AttnBiasType.NO_BIAS
         bias_shape = None
@@ -230,6 +276,7 @@ class TestDistributedCrossAttn:
             QKVLayout.BSHD_BS2HD,
             attn_bias_type,
             attn_mask_type,
+            softmax_type,
             dropout_prob,
             num_head,
             num_head,
@@ -252,6 +299,7 @@ class TestDistributedCrossAttn:
             hidden,
             attn_bias_type,
             attn_mask_type,
+            softmax_type,
             dropout_prob,
             dtype,
             is_training,
@@ -322,6 +370,8 @@ class TestDistributedContextParallelSelfAttn:
         bias_shape = None
         dropout_prob = 0.0
         is_training = True
+        # Context parallel does not support softmax_offset
+        softmax_type = AttnSoftmaxType.VANILLA_SOFTMAX
         dp_size, cp_size, tp_size = mesh_shape
         batch, seqlen, num_head, hidden = data_shape
@@ -343,6 +393,7 @@ class TestDistributedContextParallelSelfAttn:
             hidden,
             attn_bias_type,
             attn_mask_type,
+            softmax_type,
             dropout_prob,
             dtype,
             is_training,
@@ -366,6 +417,7 @@ class TestDistributedContextParallelSelfAttn:
             qkv_layout,
             attn_bias_type,
             mask_type,
+            softmax_type,
             dropout_prob,
             num_head,
             num_kv_heads,
...
@@ -16,7 +16,7 @@ from distributed_test_base import generate_configs, generate_collectives_count
 from distributed_test_base import compare_ops
 from utils import make_causal_mask, make_self_mask
 from transformer_engine.jax import autocast
-from transformer_engine.jax.softmax import SoftmaxType, softmax
+from transformer_engine.jax.softmax import SoftmaxFusionType, softmax
 DTYPES = [jnp.float16, jnp.bfloat16]
@@ -29,12 +29,12 @@ class TestDistributedSoftmax:
         return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)
     def generate_inputs(
-        self, shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask
+        self, shape, mesh_resource, softmax_fusion_type, dtype, bad_sharding, broadcast_batch_mask
     ):
         batch, _, sqelen, _ = shape
         x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
-        if softmax_type == SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
+        if softmax_fusion_type == SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED:
             mask = make_causal_mask(batch, sqelen)
         else:
             mask = make_self_mask(1 if broadcast_batch_mask else batch, sqelen)
@@ -56,8 +56,10 @@ class TestDistributedSoftmax:
         return (x, mask), (x_pspec, mask_pspec)
     @staticmethod
-    def target_func(x, mask, scale_factor=1.0, softmax_type=SoftmaxType.SCALED):
-        return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_type=softmax_type))
+    def target_func(x, mask, scale_factor=1.0, softmax_fusion_type=SoftmaxFusionType.SCALED):
+        return jnp.mean(
+            softmax(x, mask, scale_factor=scale_factor, softmax_fusion_type=softmax_fusion_type)
+        )
     @staticmethod
     def ref_func(x, mask, scale_factor=1.0, dtype=jnp.float16):
@@ -80,24 +82,29 @@ class TestDistributedSoftmax:
         mesh_axes,
         mesh_resource,
         data_shape,
-        softmax_type,
+        softmax_fusion_type,
         scale_factor,
         dtype,
         bad_sharding,
         broadcast_batch_mask,
         use_shardy,
     ):
-        if broadcast_batch_mask and softmax_type != SoftmaxType.SCALED_MASKED:
+        if broadcast_batch_mask and softmax_fusion_type != SoftmaxFusionType.SCALED_MASKED:
             pytest.skip("Softmax type has no mask.")
         jax.config.update("jax_use_shardy_partitioner", use_shardy)
         target_func = partial(
-            self.target_func, scale_factor=scale_factor, softmax_type=softmax_type
+            self.target_func, scale_factor=scale_factor, softmax_fusion_type=softmax_fusion_type
         )
         ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype)
         (x, mask), (x_pspec, mask_pspec) = self.generate_inputs(
-            data_shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask
+            data_shape,
+            mesh_resource,
+            softmax_fusion_type,
+            dtype,
+            bad_sharding,
+            broadcast_batch_mask,
         )
         collective_count_ref = self.generate_collectives_count_ref()
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
@@ -139,8 +146,12 @@ class TestDistributedSoftmax:
     @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
     @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [8, 8, 1024, 1024]])
     @pytest.mark.parametrize(
-        "softmax_type",
-        [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],
+        "softmax_fusion_type",
+        [
+            SoftmaxFusionType.SCALED,
+            SoftmaxFusionType.SCALED_MASKED,
+            SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED,
+        ],
     )
     @pytest.mark.parametrize("scale_factor", [1.0, 3.0])
     @pytest.mark.parametrize("dtype", DTYPES)
@@ -153,7 +164,7 @@ class TestDistributedSoftmax:
         mesh_axes,
         mesh_resource,
         data_shape,
-        softmax_type,
+        softmax_fusion_type,
         scale_factor,
         dtype,
         bad_sharding,
@@ -165,7 +176,7 @@ class TestDistributedSoftmax:
         mesh_axes,
         mesh_resource,
         data_shape,
-        softmax_type,
+        softmax_fusion_type,
         scale_factor,
         dtype,
         bad_sharding,
@@ -174,7 +185,9 @@ class TestDistributedSoftmax:
     )
     @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
-    @pytest.mark.parametrize("softmax_type", [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED])
+    @pytest.mark.parametrize(
+        "softmax_fusion_type", [SoftmaxFusionType.SCALED, SoftmaxFusionType.SCALED_MASKED]
+    )
     @pytest.mark.parametrize("bad_sharding", [False, True])
     @pytest.mark.parametrize("broadcast_batch_mask", [False, True])
     def test_softmax_gspmd(
@@ -183,7 +196,7 @@ class TestDistributedSoftmax:
         mesh_shape,
         mesh_axes,
         mesh_resource,
-        softmax_type,
+        softmax_fusion_type,
         bad_sharding,
         broadcast_batch_mask,
     ):
@@ -193,7 +206,7 @@ class TestDistributedSoftmax:
         mesh_axes,
         mesh_resource,
         data_shape=[32, 12, 128, 128],
-        softmax_type=softmax_type,
+        softmax_fusion_type=softmax_fusion_type,
         scale_factor=1.0,
         dtype=DTYPES[0],
         bad_sharding=bad_sharding,
...
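For downstream users of transformer_engine.jax.softmax, the rename above is mechanical: the kernel-fusion selector previously called softmax_type is now softmax_fusion_type, freeing the name softmax_type for the new AttnSoftmaxType. A before/after sketch, assuming the same call shape as target_func above:

import jax.numpy as jnp
from jax import random
from transformer_engine.jax.softmax import SoftmaxFusionType, softmax

x = random.normal(random.PRNGKey(1124), (8, 12, 128, 128), dtype=jnp.float16)
# Before this PR: softmax(x, None, scale_factor=1.0, softmax_type=SoftmaxType.SCALED)
out = softmax(x, None, scale_factor=1.0, softmax_fusion_type=SoftmaxFusionType.SCALED)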
@@ -27,6 +27,7 @@ from transformer_engine.jax.sharding import MeshResource
 from transformer_engine.jax.attention import (
     AttnBiasType,
     AttnMaskType,
+    AttnSoftmaxType,
     QKVLayout,
     QKVFormat,
     reorder_causal_load_balancing,
@@ -59,14 +60,16 @@ def init():
     yield
-@partial(jax.jit, static_argnums=(5, 6, 7, 9))
+@partial(jax.jit, static_argnums=(6, 7, 8, 9, 11))
 def general_dot_product_attention(
     query: ArrayLike,
     key: ArrayLike,
     value: ArrayLike,
+    softmax_offset: Optional[ArrayLike],
     bias: ArrayLike,
     mask: ArrayLike,
     deterministic: bool,
+    softmax_type: AttnSoftmaxType,
     scale_factor: float,
     dropout_rate: float,
     dropout_rng: ArrayLike,
@@ -99,7 +102,25 @@ def general_dot_product_attention(
         mask = jnp.expand_dims(mask, axis=-3)
         logits = jnp.where(mask, jnp.finfo(dtype).min, logits)
-    softmax_out = jax.nn.softmax(logits).astype(dtype)
+    match softmax_type:
+        case AttnSoftmaxType.VANILLA_SOFTMAX:
+            softmax_out = jax.nn.softmax(logits).astype(dtype)
+        case AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
+            # Softmax with +1 in denominator: exp(x_i) / (sum(exp(x_j)) + 1)
+            # Append a zero logit, apply standard softmax, then remove last column
+            zero_logit = jnp.zeros(logits.shape[:-1] + (1,), dtype=logits.dtype)
+            logits_with_extra = jnp.concatenate([logits, zero_logit], axis=-1)
+            softmax_with_extra = jax.nn.softmax(logits_with_extra, axis=-1)
+            softmax_out = softmax_with_extra[..., :-1].astype(dtype)
+        case AttnSoftmaxType.LEARNABLE_SOFTMAX:
+            # Append learnable offset logit, apply standard softmax, then remove last column
+            learnable_logit = softmax_offset.reshape(1, h_kv, num_groups, 1, 1)
+            learnable_logit = jnp.broadcast_to(learnable_logit, logits.shape[:-1] + (1,))
+            logits_with_extra = jnp.concatenate([logits, learnable_logit], axis=-1)
+            softmax_with_extra = jax.nn.softmax(logits_with_extra, axis=-1)
+            softmax_out = softmax_with_extra[..., :-1].astype(dtype)
+        case _:
+            raise NotImplementedError(f"Unknown {softmax_type=}")
     if not deterministic and dropout_rate > 0.0:
         keep_prob = 1.0 - dropout_rate
@@ -238,7 +259,7 @@ def _split_valid_and_invalid(primitive, reference, pad):
     return primitive_valid, primitive_invalid, reference_valid, reference_invalid
-def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
+def jax_dpa(query, key, value, bias, softmax_offset, mask, dropout_rng, **kwargs):
     """
     JAX native dot product attention implementation
     """
@@ -246,11 +267,13 @@ def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
         query,
         key,
         value,
+        softmax_offset,
         bias,
         mask,
         deterministic=not kwargs["is_training"],
         scale_factor=kwargs["scaling_factor"],
         dropout_rate=kwargs["dropout_probability"],
+        softmax_type=kwargs["softmax_type"],
         dropout_rng=dropout_rng,
         dtype=jnp.float32,
     )
@@ -262,6 +285,7 @@ def customcall_fused_dpa(
     key,
     value,
     bias,
+    softmax_offset,
     sequence_descriptor,
     dropout_rng,
     **kwargs,
@@ -283,9 +307,9 @@ def customcall_fused_dpa(
             qkv_args = (query, key, value)
         case _:
             raise ValueError(f"Unsupported {qkv_layout=}")
-    return fused_attn(qkv_args, bias, sequence_descriptor, dropout_rng, **kwargs).astype(
-        query.dtype
-    )
+    return fused_attn(
+        qkv_args, bias, sequence_descriptor, dropout_rng, softmax_offset=softmax_offset, **kwargs
+    ).astype(query.dtype)
 class BiasShape(Enum):
@@ -320,6 +344,7 @@ class FusedAttnRunner:
     head_dim_v: int
     attn_bias_type: AttnBiasType
     attn_mask_type: AttnMaskType
+    softmax_type: AttnSoftmaxType
     dropout_prob: float
     dtype: DTypeLike
     is_training: bool
@@ -402,6 +427,7 @@ class FusedAttnRunner:
             self.qkv_layout,
             self.attn_bias_type,
             self.attn_mask_type,
+            self.softmax_type,
             self.dropout_prob,
             self.num_heads_q,
             self.num_heads_kv,
@@ -439,7 +465,7 @@ class FusedAttnRunner:
         self.tp_size = self.mesh.shape.get(self.mesh_resource.tpsp_resource, 1)
         key = jax.random.PRNGKey(0)
-        q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
+        q_key, k_key, v_key, bias_key, dropout_key, softmax_key = jax.random.split(key, 6)
         q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim_qk)
         k_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_qk)
@@ -490,6 +516,13 @@ class FusedAttnRunner:
         else:
            pad_ratio = 0.0
+        if self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
+            self.softmax_offset = jax.random.uniform(
+                softmax_key, (1, self.num_heads_q, 1, 1), jnp.float32, -1.0
+            )
+        else:
+            self.softmax_offset = None
        def gen_valid(bs, max_seqlen, pad_ratio):
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
@@ -713,6 +746,16 @@ class FusedAttnRunner:
            self.bias_pspec = PartitionSpec()
        self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec)
+        # Softmax offset sharding (1, num_heads, 1, 1)
+        # Use the same logic as HEAD_AXES: tpsp_resource if enabled, else tp_resource
+        head_resource = (
+            self.mesh_resource.tpsp_resource
+            if self.mesh_resource.tpsp_resource is not None
+            else self.mesh_resource.tp_resource
+        )
+        self.softmax_offset_pspec = PartitionSpec(None, head_resource, None, None)
+        self.softmax_offset_sharding = NamedSharding(self.mesh, self.softmax_offset_pspec)
        self.dropout_rng_pspec = PartitionSpec(
            None,
        )
@@ -732,7 +775,7 @@ class FusedAttnRunner:
        """
        self._setup_inputs()
-        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
+        args = [self.q, self.k, self.v, self.bias, self.softmax_offset, self.mask, self.dropout_rng]
        customcall_args = [
            # Put test data onto each GPU for distributed.
@@ -742,12 +785,14 @@ class FusedAttnRunner:
            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
            jax.device_put(self.bias, self.bias_sharding),
+            jax.device_put(self.softmax_offset, self.softmax_offset_sharding),
            jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
+            "softmax_type": self.softmax_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
@@ -766,6 +811,7 @@ class FusedAttnRunner:
            self.qkvo_sharding,
            self.qkvo_sharding,
            self.bias_sharding,
+            self.softmax_offset_sharding,
            self.seq_desc_sharding,
            self.dropout_rng_sharding,
        ],
@@ -826,7 +872,7 @@ class FusedAttnRunner:
            jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
        ).astype(self.dtype)
-        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
+        args = [self.q, self.k, self.v, self.bias, self.softmax_offset, self.mask, self.dropout_rng]
        customcall_args = [
            # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
            # THD params once we support those features on CP.
@@ -834,12 +880,14 @@ class FusedAttnRunner:
            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
            jax.device_put(self.bias, self.bias_sharding),
+            jax.device_put(self.softmax_offset, self.softmax_offset_sharding),
            jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
+            "softmax_type": self.softmax_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
@@ -866,8 +914,16 @@ class FusedAttnRunner:
        # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
        jitted_primitive = jit(
            value_and_grad(
-                lambda q, k, v, bias, *args: grad_func(
-                    customcall_fused_dpa, q, k, v, bias, *args, cp_reverse_out=True, **kwargs
+                lambda q, k, v, bias, softmax_offset, *args: grad_func(
+                    customcall_fused_dpa,
+                    q,
+                    k,
+                    v,
+                    bias,
+                    softmax_offset,
+                    *args,
+                    cp_reverse_out=True,
+                    **kwargs,
                ),
                arg_nums,
            ),
@@ -876,6 +932,7 @@ class FusedAttnRunner:
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
+                self.softmax_offset_sharding,
                self.seq_desc_sharding,
                self.dropout_rng_sharding,
            ),
@@ -883,7 +940,9 @@ class FusedAttnRunner:
        )
        jitted_reference = jit(
            value_and_grad(
-                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
+                lambda q, k, v, bias, softmax_offset, *args: grad_func(
+                    jax_dpa, q, k, v, bias, softmax_offset, *args, **kwargs
+                ),
                arg_nums,
            )
        )
@@ -976,6 +1035,14 @@ class FusedAttnRunner:
            ),
        ],
    )
+    @pytest.mark.parametrize(
+        "softmax_type",
+        [
+            pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
+            pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
+            pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
+        ],
+    )
    @pytest.mark.parametrize(
        "qkv_layout",
        [
@@ -1084,6 +1151,7 @@ class TestFusedAttn:
        d_v,
        attn_bias_type,
        attn_mask_type,
+        softmax_type,
        dropout_prob,
        dtype,
        is_training,
@@ -1110,6 +1178,7 @@ class TestFusedAttn:
            d_v,
            attn_bias_type,
            attn_mask_type,
+            softmax_type,
            dropout_prob,
            dtype,
            is_training,
@@ -1138,6 +1207,7 @@ class TestFusedAttn:
        d_v,
        attn_bias_type,
        attn_mask_type,
+        softmax_type,
        dropout_prob,
        dtype,
        qkv_layout,
@@ -1161,6 +1231,7 @@ class TestFusedAttn:
            d_v,
            attn_bias_type,
            attn_mask_type,
+            softmax_type,
            dropout_prob,
            dtype,
            True,
...
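For the LEARNABLE_SOFTMAX path, the reference implementation above reshapes a (1, num_heads, 1, 1) offset and appends it as one extra logit column before a standard softmax. A standalone check (pure JAX, no TE imports) that this concat trick matches the closed form exp(x_i) / (exp(alpha) + sum_j exp(x_j)) used in the AttnSoftmaxType docstring further down:

import jax
import jax.numpy as jnp

b, h, q, k = 2, 4, 8, 8
key_logits, key_alpha = jax.random.split(jax.random.PRNGKey(0))
logits = jax.random.normal(key_logits, (b, h, q, k))
alpha = jax.random.uniform(key_alpha, (1, h, 1, 1), minval=-1.0)  # per-head sink logit

# Concat trick, as in general_dot_product_attention above:
extra = jnp.broadcast_to(alpha, (b, h, q, 1))
via_concat = jax.nn.softmax(jnp.concatenate([logits, extra], axis=-1), axis=-1)[..., :-1]

# Direct form: exp(x_i) / (exp(alpha_h) + sum_j exp(x_j))
direct = jnp.exp(logits) / (jnp.exp(alpha) + jnp.exp(logits).sum(axis=-1, keepdims=True))
assert jnp.allclose(via_concat, direct, atol=1e-6)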
@@ -83,6 +83,7 @@ _KEY_OF_FLOAT32_ATTENTION_LOGITS = "float32_attention_logits"
 _KEY_OF_USE_BIAS = "use_bias"
 _KEY_OF_RELATIVE_EMBEDDING = "enable_relative_embedding"
 _KEY_OF_WINDOW_SIZE = "window_size"
+_KEY_OF_SOFTMAX_TYPE = "softmax_type"
 BASE_ATTRS = {
     _KEY_OF_TRANSPOSE_BS: True,
@@ -276,6 +277,14 @@ ATTRS = [
         _KEY_OF_RELATIVE_EMBEDDING: True,
         _KEY_OF_SELF_ATTN_BIAS_TYPE: "post_scale_bias",
     },
+    # attrs31
+    {
+        _KEY_OF_SOFTMAX_TYPE: "off_by_one",
+    },
+    # attrs32
+    {
+        _KEY_OF_SOFTMAX_TYPE: "learnable",
+    },
 ]
 ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
@@ -418,6 +427,9 @@ class EncoderRunner(BaseRunner):
         "attention/qkv/ln_bias": "pre_attention_layer_norm/ln_bias",
         "attention/query/scale": "pre_attention_layer_norm/scale",
         "attention/query/ln_bias": "pre_attention_layer_norm/ln_bias",
+        "attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset": (
+            "attention/DotProductAttention_0/softmax_offset"
+        ),
         "mlp/wi_kernel": "mlp/wi/kernel",
         "mlp/wi_bias": "mlp/wi/bias",
         "mlp/wo_kernel": "mlp/wo/kernel",
@@ -463,10 +475,16 @@ class DecoderRunner(BaseRunner):
         "encoder_decoder_attention/qkv/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
         "encoder_decoder_attention/query/scale": "pre_cross_attention_layer_norm/scale",
         "encoder_decoder_attention/query/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
+        "encoder_decoder_attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset": (
+            "encoder_decoder_attention/DotProductAttention_0/softmax_offset"
+        ),
         "self_attention/qkv/scale": "pre_self_attention_layer_norm/scale",
         "self_attention/qkv/ln_bias": "pre_self_attention_layer_norm/ln_bias",
         "self_attention/query/scale": "pre_self_attention_layer_norm/scale",
         "self_attention/query/ln_bias": "pre_self_attention_layer_norm/ln_bias",
+        "self_attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset": (
+            "self_attention/DotProductAttention_0/softmax_offset"
+        ),
         "mlp/wi_kernel": "mlp/wi/kernel",
         "mlp/wi_bias": "mlp/wi/bias",
         "mlp/wo_kernel": "mlp/wo/kernel",
...
@@ -17,7 +17,8 @@ from jax.typing import DTypeLike
 from utils import assert_allclose
 from transformer_engine.jax.cpp_extensions import is_softmax_kernel_available
-from transformer_engine.jax.softmax import SoftmaxType, softmax
+from transformer_engine.jax.cpp_extensions.attention import AttnSoftmaxType
+from transformer_engine.jax.softmax import SoftmaxFusionType, softmax
 from transformer_engine.jax.flax.module import Softmax
@@ -50,8 +51,9 @@ class SoftmaxRunner:
     max_seqlen_kv: int
     num_heads: int
     scale_factor: float
-    softmax_type: SoftmaxType
+    softmax_fusion_type: SoftmaxFusionType
     dtype: DTypeLike
+    softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX
     @staticmethod
     def reference_softmax(logits, mask, scale_factor, **_):
@@ -68,6 +70,7 @@ class SoftmaxRunner:
     def _is_support(self):
         return is_softmax_kernel_available(
+            self.softmax_fusion_type,
             self.softmax_type,
             self.batch_size,
             self.num_heads,
@@ -85,22 +88,22 @@ class SoftmaxRunner:
         self.logits = jax.random.uniform(logits_key, logits_shape, self.dtype, -1.0)
-        match self.softmax_type:
-            case SoftmaxType.SCALED:
+        match self.softmax_fusion_type:
+            case SoftmaxFusionType.SCALED:
                 self.mask = None
-            case SoftmaxType.SCALED_MASKED:
+            case SoftmaxFusionType.SCALED_MASKED:
                 self.mask = jax.random.bernoulli(mask_key, shape=mask_shape).astype(jnp.uint8)
-            case SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
+            case SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED:
                 self.mask = (1.0 - jnp.tril(jnp.ones_like(self.logits))).astype(jnp.uint8)
             case _:
-                raise ValueError(f"Unknown {self.softmax_type=}")
+                raise ValueError(f"Unknown {self.softmax_fusion_type=}")
     def test_forward(self):
         """
         Test transformer_engine.jax.softmax.softmax fwd rule
         """
         self._setup_inputs()
-        primitive_out = softmax(self.logits, self.mask, self.scale_factor, self.softmax_type)
+        primitive_out = softmax(self.logits, self.mask, self.scale_factor, self.softmax_fusion_type)
         reference_out = __class__.reference_softmax(self.logits, self.mask, self.scale_factor)
         assert_allclose(primitive_out, reference_out, dtype=self.dtype)
@@ -117,7 +120,7 @@ class SoftmaxRunner:
         args = [self.logits, self.mask]
         kwargs = {
             "scale_factor": self.scale_factor,
-            "softmax_type": self.softmax_type,
+            "softmax_fusion_type": self.softmax_fusion_type,
         }
         # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
@@ -175,7 +178,7 @@ class SoftmaxModuleRunner:
         rng = jax.random.PRNGKey(0)
         softmax_module = Softmax(
             scale_factor=runner.scale_factor,
-            softmax_type=runner.softmax_type,
+            softmax_fusion_type=runner.softmax_fusion_type,
         )
         softmax_vars = softmax_module.init(rng, runner.logits, runner.mask)
         module_out = softmax_module.apply(softmax_vars, runner.logits, runner.mask)
@@ -194,11 +197,11 @@ class SoftmaxModuleRunner:
 )
 @pytest.mark.parametrize("scale_factor", [0.125])
 @pytest.mark.parametrize(
-    "softmax_type",
+    "softmax_fusion_type",
     [
-        pytest.param(SoftmaxType.SCALED, id="SCALED"),
-        pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"),
-        pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
+        pytest.param(SoftmaxFusionType.SCALED, id="SCALED"),
+        pytest.param(SoftmaxFusionType.SCALED_MASKED, id="SCALED_MASKED"),
+        pytest.param(SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
     ],
 )
 @pytest.mark.parametrize(
@@ -214,19 +217,19 @@ class TestSoftmaxPrimitives:
     """
     @staticmethod
-    def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
+    def test_forward(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype):
         """
         Test forward with parameterized configs
         """
-        runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
+        runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype)
         runner.test_forward()
     @staticmethod
-    def test_backward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
+    def test_backward(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype):
         """
         Test forward with parameterized configs
         """
-        runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
+        runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype)
         runner.test_backward()
@@ -243,11 +246,11 @@ class TestSoftmaxPrimitives:
 )
 @pytest.mark.parametrize("scale_factor", [0.125])
 @pytest.mark.parametrize(
-    "softmax_type",
+    "softmax_fusion_type",
     [
-        pytest.param(SoftmaxType.SCALED, id="SCALED"),
-        pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"),
-        pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
+        pytest.param(SoftmaxFusionType.SCALED, id="SCALED"),
+        pytest.param(SoftmaxFusionType.SCALED_MASKED, id="SCALED_MASKED"),
+        pytest.param(SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
     ],
 )
 @pytest.mark.parametrize(
@@ -263,11 +266,11 @@ class TestSoftmaxModule:
     """
     @staticmethod
-    def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
+    def test_forward(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype):
         """
         Test forward with parameterized configs
         """
-        module_runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
+        module_runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype)
         bias = None
         runner = SoftmaxModuleRunner(module_runner, bias)
         runner.test_forward()
@@ -21,6 +21,7 @@ from jax import random as jax_random
 import pytest
 from transformer_engine.jax.attention import (
+    AttnSoftmaxType,
     canonicalize_attn_mask_type,
     make_swa_mask,
 )
@@ -162,6 +163,7 @@ class DotProductAttention(nn.Module):
     dropout_rate: float = 0.0
     dtype: DType = jnp.float32
     float32_logits: bool = False
+    softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX
     """Computes dot-product attention given query, key, and value.
     This is the core function for applying attention based on
@@ -211,6 +213,24 @@ class DotProductAttention(nn.Module):
         assert key.shape[-2] == value.shape[-2], "k, v num_heads must match."
         assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."
+        # Infer number of attention heads from query shape
+        # query shape: [..., h, d] where h is num_attention_heads
+        num_attention_heads = query.shape[-2]
+        # Initialize softmax_offset for off-by-one or learnable softmax
+        softmax_offset = None
+        if self.softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
+            # For off-by-one softmax, use zeros with shape (1, h, 1, 1)
+            softmax_offset = jnp.zeros((1, num_attention_heads, 1, 1), dtype=input_dtype)
+        elif self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
+            # For learnable softmax, create a learnable parameter with shape (1, h, 1, 1)
+            softmax_offset = self.param(
+                "softmax_offset",
+                nn.initializers.zeros,
+                (1, num_attention_heads, 1, 1),
+                jnp.float32,
+            )
         if self.scale_attn_logits:
             head_dim = query.shape[-1]
             depth_scaling = jnp.sqrt(head_dim).astype(input_dtype)
@@ -241,9 +261,23 @@ class DotProductAttention(nn.Module):
         if bias is not None:
             attn_weights = attn_weights + bias.astype(attn_weights.dtype)
+        # Add attention sink to the last column if not vanilla softmax
+        if self.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
+            # Add extra column with softmax_offset
+            # softmax_offset shape: (1, h, 1, 1), attn_weights shape: [b, h, q, k]
+            extra_col = jnp.broadcast_to(
+                softmax_offset,
+                (attn_weights.shape[0], attn_weights.shape[1], attn_weights.shape[2], 1),
+            )
+            attn_weights = jnp.concatenate([attn_weights, extra_col], axis=-1)
         # Normalize the attention weights across `kv_length` dimension.
         attn_weights = jax_nn.softmax(attn_weights).astype(input_dtype)
+        # Remove the extra column after softmax if not vanilla softmax
+        if self.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
+            attn_weights = attn_weights[..., :-1]
         # Apply attention dropout.
         if not deterministic and self.dropout_rate > 0.0:
             keep_prob = 1.0 - self.dropout_rate
@@ -535,6 +569,7 @@ class MultiHeadAttention(nn.Module):
     rotary_pos_emb_group_method: str = "consecutive"
     fuse_qkv: bool = True
     use_bias: bool = False
+    softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX
     def __post_init__(self):
         if self.kernel_init is None:
@@ -801,6 +836,7 @@ class MultiHeadAttention(nn.Module):
             dropout_rate=self.dropout_rate,
             dtype=self.dtype,
             float32_logits=self.float32_logits,
+            softmax_type=self.softmax_type,
         )(query, key, value, bias=attention_bias, deterministic=deterministic)
         x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
@@ -1058,6 +1094,7 @@ class EncoderLayer(nn.Module):
     self_attn_bias_type: Any = None
     self_attn_mask_type: str = "no_mask"
     window_size: Tuple[int, int] = (-1, -1)
+    softmax_type: str = "vanilla"
     def __post_init__(self):
         if self.num_gqa_groups is None:
@@ -1111,6 +1148,9 @@ class EncoderLayer(nn.Module):
         else:
             x = inputs
+        # Convert softmax_type string to AttnSoftmaxType enum
+        attn_softmax_type = AttnSoftmaxType.from_str(self.softmax_type)
         # [batch, length, emb_dim] -> [batch, length, emb_dim]
         x = MultiHeadAttention(
             num_heads=self.num_attention_heads,
@@ -1126,6 +1166,7 @@ class EncoderLayer(nn.Module):
             enable_rotary_pos_emb=self.enable_rotary_pos_emb,
             rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
             use_bias=self.use_bias,
+            softmax_type=attn_softmax_type,
             name="attention",
         )(x, x, encoder_mask, encoder_bias, deterministic=deterministic)
         x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
@@ -1222,6 +1263,7 @@ class DecoderLayer(nn.Module):
     self_attn_bias_type: Any = None
     self_attn_mask_type: str = "no_mask"
     window_size: Tuple[int, int] = (-1, -1)
+    softmax_type: str = "vanilla"
     def __post_init__(self):
         if self.num_gqa_groups is None:
@@ -1290,6 +1332,9 @@ class DecoderLayer(nn.Module):
         else:
             x = inputs
+        # Convert softmax_type string to AttnSoftmaxType enum
+        attn_softmax_type = AttnSoftmaxType.from_str(self.softmax_type)
         # Self-attention block
         x = MultiHeadAttention(
             num_heads=self.num_attention_heads,
@@ -1305,6 +1350,7 @@ class DecoderLayer(nn.Module):
             rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
             fuse_qkv=self.fuse_qkv_params,
             use_bias=self.use_bias,
+            softmax_type=attn_softmax_type,
             name="self_attention",
         )(x, x, decoder_mask, decoder_bias, deterministic=deterministic, decode=decode)
         x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
@@ -1343,6 +1389,7 @@ class DecoderLayer(nn.Module):
             rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
             fuse_qkv=self.fuse_qkv_params,
             use_bias=self.use_bias,
+            softmax_type=attn_softmax_type,
             name="encoder_decoder_attention",
         )(y, encoded, encoder_decoder_mask, deterministic=deterministic)
         y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
...
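The unfused Flax path above keeps the sink logic self-contained: a zeros (off-by-one) or learned (learnable) offset of shape (1, h, 1, 1) is appended as one extra column before the softmax and dropped afterwards. A condensed, hypothetical helper module showing just that mechanic (SinkSoftmax is not a TE class; the parameter name and shape match the checkpoint mappings in the runners above):

from flax import linen as nn
import jax
import jax.numpy as jnp

class SinkSoftmax(nn.Module):  # hypothetical helper, not part of Transformer Engine
    num_heads: int

    @nn.compact
    def __call__(self, attn_weights):  # attn_weights: [b, h, q, k]
        # Learnable per-head sink logit, same name/shape as DotProductAttention above
        offset = self.param(
            "softmax_offset", nn.initializers.zeros, (1, self.num_heads, 1, 1), jnp.float32
        )
        extra_col = jnp.broadcast_to(
            offset.astype(attn_weights.dtype), attn_weights.shape[:-1] + (1,)
        )
        probs = jax.nn.softmax(jnp.concatenate([attn_weights, extra_col], axis=-1))
        return probs[..., :-1]  # drop the sink column after normalization

w = jnp.zeros((2, 4, 8, 8))
params = SinkSoftmax(num_heads=4).init(jax.random.PRNGKey(0), w)
probs = SinkSoftmax(num_heads=4).apply(params, w)  # each row sums to k/(k+1) at init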
...@@ -18,6 +18,7 @@ from transformer_engine_jax import NVTE_Mask_Type ...@@ -18,6 +18,7 @@ from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_QKV_Format from transformer_engine_jax import NVTE_QKV_Format
from transformer_engine_jax import nvte_get_qkv_format from transformer_engine_jax import nvte_get_qkv_format
from transformer_engine_jax import NVTE_Softmax_Type
from . import cpp_extensions as tex from . import cpp_extensions as tex
...@@ -74,6 +75,35 @@ class AttnMaskType(Enum): ...@@ -74,6 +75,35 @@ class AttnMaskType(Enum):
] ]
class AttnSoftmaxType(Enum):
"""
VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)),
LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter in shape [H].
"""
VANILLA_SOFTMAX = NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX
OFF_BY_ONE_SOFTMAX = NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX
LEARNABLE_SOFTMAX = NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX
@classmethod
def from_str(cls, softmax_type: str) -> "AttnSoftmaxType":
"""Convert string to AttnSoftmaxType: 'vanilla', 'off_by_one', or 'learnable'."""
softmax_type_map = {
"vanilla": cls.VANILLA_SOFTMAX,
"off_by_one": cls.OFF_BY_ONE_SOFTMAX,
"learnable": cls.LEARNABLE_SOFTMAX,
}
result = softmax_type_map.get(softmax_type)
if result is None:
raise ValueError(
f"Unknown softmax_type: {softmax_type}. "
"Valid options: 'vanilla', 'off_by_one', 'learnable'"
)
return result
class QKVFormat(Enum): class QKVFormat(Enum):
""" """
SBHD: q,k,v memory layout with [s, b, ..., h, d] SBHD: q,k,v memory layout with [s, b, ..., h, d]
...@@ -301,6 +331,7 @@ def is_fused_attn_kernel_available( ...@@ -301,6 +331,7 @@ def is_fused_attn_kernel_available(
qkv_layout, qkv_layout,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
softmax_type,
dropout_probability, dropout_probability,
q_num_heads, q_num_heads,
kv_num_heads, kv_num_heads,
...@@ -313,6 +344,7 @@ def is_fused_attn_kernel_available( ...@@ -313,6 +344,7 @@ def is_fused_attn_kernel_available(
""" """
To check whether the fused attention kernel is supported To check whether the fused attention kernel is supported
""" """
window_size_tuple = (-1, -1) if window_size is None else window_size
def make_helper(attn_mask_type): def make_helper(attn_mask_type):
return tex.FusedAttnHelper( return tex.FusedAttnHelper(
...@@ -322,6 +354,7 @@ def is_fused_attn_kernel_available( ...@@ -322,6 +354,7 @@ def is_fused_attn_kernel_available(
qkv_layout, qkv_layout,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
softmax_type,
dropout_probability, dropout_probability,
q_num_heads, q_num_heads,
kv_num_heads, kv_num_heads,
...@@ -329,7 +362,7 @@ def is_fused_attn_kernel_available( ...@@ -329,7 +362,7 @@ def is_fused_attn_kernel_available(
kv_max_seqlen, kv_max_seqlen,
head_dim_qk, head_dim_qk,
head_dim_v, head_dim_v,
(-1, -1) if window_size is None else window_size, window_size_tuple,
) )
return make_helper(attn_mask_type).is_fused_attn_kernel_available() return make_helper(attn_mask_type).is_fused_attn_kernel_available()
...@@ -786,6 +819,7 @@ def _legacy_fused_attn( ...@@ -786,6 +819,7 @@ def _legacy_fused_attn(
attn_bias_type: AttnBiasType, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, attn_mask_type: AttnMaskType,
qkv_layout: QKVLayout, qkv_layout: QKVLayout,
softmax_type: AttnSoftmaxType,
scaling_factor: float, scaling_factor: float,
dropout_probability: float, dropout_probability: float,
is_training: bool, is_training: bool,
...@@ -793,6 +827,7 @@ def _legacy_fused_attn( ...@@ -793,6 +827,7 @@ def _legacy_fused_attn(
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT, context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False, context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "", context_parallel_axis: str = "",
softmax_offset: Optional[jnp.ndarray] = None,
): ):
""" """
Perform non-THD (non-packed) cuDNN fused attention. Perform non-THD (non-packed) cuDNN fused attention.
...@@ -815,6 +850,7 @@ def _legacy_fused_attn( ...@@ -815,6 +850,7 @@ def _legacy_fused_attn(
seed (Optional[jnp.ndarray]): Optional random seed for dropout. seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (AttnBiasType): Type of attention bias. attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask. attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of attention softmax.
qkv_layout (QKVLayout): Layout of the QKV tensors. qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores. scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention. dropout_probability (float): Dropout probability to apply during attention.
...@@ -863,10 +899,12 @@ def _legacy_fused_attn( ...@@ -863,10 +899,12 @@ def _legacy_fused_attn(
output = _fused_attn( output = _fused_attn(
qkv, qkv,
bias, bias,
softmax_offset,
SequenceDescriptor.from_seqlens((q_seq_lens, kv_seq_lens)), SequenceDescriptor.from_seqlens((q_seq_lens, kv_seq_lens)),
seed, seed,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
softmax_type=softmax_type,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
...@@ -900,6 +938,7 @@ def fused_attn_thd( ...@@ -900,6 +938,7 @@ def fused_attn_thd(
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT, context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False, context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "", context_parallel_axis: str = "",
softmax_offset: Optional[jnp.ndarray] = None,
): ):
""" """
Deprecated THD fused attention; please use fused_attn with SequenceDescriptor Deprecated THD fused attention; please use fused_attn with SequenceDescriptor
...@@ -937,6 +976,7 @@ def fused_attn_thd( ...@@ -937,6 +976,7 @@ def fused_attn_thd(
output = _fused_attn( output = _fused_attn(
qkv, qkv,
bias, bias,
softmax_offset,
SequenceDescriptor.from_seqlens_and_offsets( SequenceDescriptor.from_seqlens_and_offsets(
(q_seq_lens, kv_seq_lens), (q_seq_offsets, kv_seq_offsets) (q_seq_lens, kv_seq_lens), (q_seq_offsets, kv_seq_offsets)
), ),
...@@ -945,6 +985,7 @@ def fused_attn_thd( ...@@ -945,6 +985,7 @@ def fused_attn_thd(
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
softmax_type=AttnSoftmaxType.VANILLA_SOFTMAX,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training, is_training=is_training,
max_segments_per_seq=max_segments_per_seq, max_segments_per_seq=max_segments_per_seq,
...@@ -957,15 +998,17 @@ def fused_attn_thd( ...@@ -957,15 +998,17 @@ def fused_attn_thd(
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)) @partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _fused_attn( def _fused_attn(
qkv: Tuple[jnp.ndarray, ...], qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray], bias: Optional[jnp.ndarray],
softmax_offset: Optional[jnp.ndarray],
sequence_descriptor: SequenceDescriptor, sequence_descriptor: SequenceDescriptor,
seed: Optional[jnp.ndarray], seed: Optional[jnp.ndarray],
attn_bias_type: AttnBiasType, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, attn_mask_type: AttnMaskType,
qkv_layout: QKVLayout, qkv_layout: QKVLayout,
softmax_type: AttnSoftmaxType,
scaling_factor: float, scaling_factor: float,
dropout_probability: float, dropout_probability: float,
is_training: bool, is_training: bool,
...@@ -979,11 +1022,13 @@ def _fused_attn( ...@@ -979,11 +1022,13 @@ def _fused_attn(
output, _ = _fused_attn_fwd_rule( output, _ = _fused_attn_fwd_rule(
qkv, qkv,
bias, bias,
softmax_offset,
sequence_descriptor, sequence_descriptor,
seed, seed,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
qkv_layout, qkv_layout,
softmax_type,
scaling_factor, scaling_factor,
dropout_probability, dropout_probability,
is_training, is_training,
...@@ -1000,11 +1045,13 @@ def _fused_attn( ...@@ -1000,11 +1045,13 @@ def _fused_attn(
def _fused_attn_fwd_rule( def _fused_attn_fwd_rule(
qkv, qkv,
bias, bias,
softmax_offset,
sequence_descriptor, sequence_descriptor,
seed, seed,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
qkv_layout, qkv_layout,
softmax_type,
scaling_factor, scaling_factor,
dropout_probability, dropout_probability,
is_training, is_training,
...@@ -1018,10 +1065,12 @@ def _fused_attn_fwd_rule( ...@@ -1018,10 +1065,12 @@ def _fused_attn_fwd_rule(
output, softmax_aux, rng_state = tex.fused_attn_fwd( output, softmax_aux, rng_state = tex.fused_attn_fwd(
qkv, qkv,
bias, bias,
softmax_offset,
sequence_descriptor, sequence_descriptor,
seed, seed,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
softmax_type=softmax_type,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
...@@ -1041,6 +1090,7 @@ def _fused_attn_fwd_rule( ...@@ -1041,6 +1090,7 @@ def _fused_attn_fwd_rule(
sequence_descriptor, sequence_descriptor,
softmax_aux, softmax_aux,
rng_state, rng_state,
softmax_offset,
output, output,
) )
...@@ -1049,6 +1099,7 @@ def _fused_attn_bwd_rule( ...@@ -1049,6 +1099,7 @@ def _fused_attn_bwd_rule(
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
qkv_layout, qkv_layout,
softmax_type,
scaling_factor, scaling_factor,
dropout_probability, dropout_probability,
is_training, is_training,
...@@ -1068,11 +1119,13 @@ def _fused_attn_bwd_rule( ...@@ -1068,11 +1119,13 @@ def _fused_attn_bwd_rule(
sequence_descriptor, sequence_descriptor,
softmax_aux, softmax_aux,
rng_state, rng_state,
softmax_offset,
output, output,
) = ctx ) = ctx
grad_qkv, grad_bias = tex.fused_attn_bwd( grad_qkv, grad_bias, grad_softmax_offset = tex.fused_attn_bwd(
qkv, qkv,
bias, bias,
softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -1080,6 +1133,7 @@ def _fused_attn_bwd_rule( ...@@ -1080,6 +1133,7 @@ def _fused_attn_bwd_rule(
sequence_descriptor, sequence_descriptor,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
softmax_type=softmax_type,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
...@@ -1092,9 +1146,12 @@ def _fused_attn_bwd_rule( ...@@ -1092,9 +1146,12 @@ def _fused_attn_bwd_rule(
) )
if attn_bias_type == AttnBiasType.NO_BIAS: if attn_bias_type == AttnBiasType.NO_BIAS:
grad_bias = None grad_bias = None
if softmax_type != AttnSoftmaxType.LEARNABLE_SOFTMAX:
grad_softmax_offset = None
return ( return (
grad_qkv, grad_qkv,
grad_bias, grad_bias,
grad_softmax_offset,
None, None,
None, None,
) )
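Note on the gradient plumbing above: grad_softmax_offset is propagated only for LEARNABLE_SOFTMAX and dropped (set to None) otherwise. As a cross-check of the math, here is a minimal pure-JAX reference of a learnable-sink softmax — illustrative only, not the cuDNN kernel; the offset enters the result only through the denominator, which is where its gradient comes from:

```python
import jax
import jax.numpy as jnp

def sink_softmax(logits, sink):
    # `sink` is a learnable logit added to the softmax denominator.
    m = jnp.maximum(jnp.max(logits, axis=-1, keepdims=True), sink)
    num = jnp.exp(logits - m)
    denom = jnp.sum(num, axis=-1, keepdims=True) + jnp.exp(sink - m)
    return num / denom

logits = jnp.array([[1.0, 2.0, 3.0]])
sink = jnp.zeros(())  # scalar sink for illustration; the kernel uses [1, H, 1, 1]

# The sink receives a well-defined gradient through the denominator.
d_sink = jax.grad(lambda s: sink_softmax(logits, s).sum())(sink)
```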
...@@ -1111,6 +1168,7 @@ def fused_attn( ...@@ -1111,6 +1168,7 @@ def fused_attn(
attn_bias_type: AttnBiasType, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, attn_mask_type: AttnMaskType,
qkv_layout: QKVLayout, qkv_layout: QKVLayout,
softmax_type: AttnSoftmaxType,
scaling_factor: float, scaling_factor: float,
dropout_probability: float, dropout_probability: float,
is_training: bool, is_training: bool,
...@@ -1120,6 +1178,7 @@ def fused_attn( ...@@ -1120,6 +1178,7 @@ def fused_attn(
context_parallel_causal_load_balanced: bool = False, context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "", context_parallel_axis: str = "",
context_checkpoint_name: str = "context", context_checkpoint_name: str = "context",
softmax_offset: Optional[jnp.ndarray] = None,
): ):
""" """
Perform cuDNN fused attention. Perform cuDNN fused attention.
...@@ -1139,6 +1198,7 @@ def fused_attn( ...@@ -1139,6 +1198,7 @@ def fused_attn(
seed (Optional[jnp.ndarray]): Optional random seed for dropout. seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (AttnBiasType): Type of attention bias. attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask. attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of attention softmax.
qkv_layout (QKVLayout): Layout of the QKV tensors. qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores. scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention. dropout_probability (float): Dropout probability to apply during attention.
...@@ -1153,6 +1213,9 @@ def fused_attn( ...@@ -1153,6 +1213,9 @@ def fused_attn(
Indicates the sequences are ordered for causal mask load balancing when running context parallelism. Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis. context_parallel_axis (str): The name of the context parallel axis.
context_checkpoint_name (str): The name of the context checkpoint for the custom VJP forward pass. context_checkpoint_name (str): The name of the context checkpoint for the custom VJP forward pass.
softmax_offset (Optional[jnp.ndarray]): An optional learnable softmax offset tensor with shape
[1, num_heads, 1, 1]. Used when softmax_type is AttnSoftmaxType.LEARNABLE_SOFTMAX.
If provided, this parameter will receive gradients during backpropagation.
Returns: Returns:
(jnp.ndarray): The output tensor from the fused attention. (jnp.ndarray): The output tensor from the fused attention.
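For orientation, a minimal usage sketch of the learnable-sink path — hedged: the leading positional order (qkv, bias, sequence_descriptor, seed) follows the docstring above, and `QKVLayout.BSHD_BSHD_BSHD` is an assumed member name for separate q/k/v tensors, not verified here:

```python
import jax.numpy as jnp
from transformer_engine.jax.attention import (
    fused_attn,
    AttnBiasType,
    AttnMaskType,
    AttnSoftmaxType,
    QKVLayout,
    SequenceDescriptor,
)

B, S, H, D = 2, 128, 8, 64
q = jnp.zeros((B, S, H, D), dtype=jnp.bfloat16)
k = jnp.zeros((B, S, H, D), dtype=jnp.bfloat16)
v = jnp.zeros((B, S, H, D), dtype=jnp.bfloat16)

# One learnable fp32 sink logit per head; this array receives gradients.
softmax_offset = jnp.zeros((1, H, 1, 1), dtype=jnp.float32)

seqlens = jnp.full((B,), S, dtype=jnp.int32)
seq_desc = SequenceDescriptor.from_seqlens((seqlens, seqlens))

out = fused_attn(
    (q, k, v),
    None,  # bias
    seq_desc,
    None,  # seed (dropout disabled)
    attn_bias_type=AttnBiasType.NO_BIAS,
    attn_mask_type=AttnMaskType.CAUSAL_MASK,
    qkv_layout=QKVLayout.BSHD_BSHD_BSHD,  # assumed name
    softmax_type=AttnSoftmaxType.LEARNABLE_SOFTMAX,
    scaling_factor=D**-0.5,
    dropout_probability=0.0,
    is_training=True,
    softmax_offset=softmax_offset,
)
```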
...@@ -1200,6 +1263,7 @@ def fused_attn( ...@@ -1200,6 +1263,7 @@ def fused_attn(
seed, seed,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
softmax_type=softmax_type,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
...@@ -1208,15 +1272,18 @@ def fused_attn( ...@@ -1208,15 +1272,18 @@ def fused_attn(
context_parallel_strategy=context_parallel_strategy, context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis, context_parallel_axis=context_parallel_axis,
softmax_offset=softmax_offset,
) )
output = _fused_attn( output = _fused_attn(
qkv, qkv,
bias, bias,
softmax_offset,
sequence_descriptor, sequence_descriptor,
seed, seed,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
softmax_type=softmax_type,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training, is_training=is_training,
......
...@@ -20,11 +20,13 @@ from transformer_engine_jax import NVTE_Fused_Attn_Backend ...@@ -20,11 +20,13 @@ from transformer_engine_jax import NVTE_Fused_Attn_Backend
from transformer_engine.jax.attention import ( from transformer_engine.jax.attention import (
AttnBiasType, AttnBiasType,
AttnMaskType, AttnMaskType,
AttnSoftmaxType,
QKVLayout, QKVLayout,
QKVFormat, QKVFormat,
CPStrategy, CPStrategy,
SequenceDescriptor, SequenceDescriptor,
) )
from ..sharding import with_sharding_constraint_by_logical_axes, HEAD_AXES
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
from .misc import ( from .misc import (
...@@ -61,6 +63,7 @@ __all__ = [ ...@@ -61,6 +63,7 @@ __all__ = [
meta_fields=[ meta_fields=[
"attn_bias_type", "attn_bias_type",
"attn_mask_type", "attn_mask_type",
"softmax_type",
"qkv_layout", "qkv_layout",
"scaling_factor", "scaling_factor",
"dropout_probability", "dropout_probability",
...@@ -80,6 +83,7 @@ class _FusedAttnConfig: ...@@ -80,6 +83,7 @@ class _FusedAttnConfig:
attn_bias_type: AttnBiasType attn_bias_type: AttnBiasType
attn_mask_type: AttnMaskType attn_mask_type: AttnMaskType
softmax_type: AttnSoftmaxType
qkv_layout: QKVLayout qkv_layout: QKVLayout
scaling_factor: float scaling_factor: float
dropout_probability: float dropout_probability: float
...@@ -103,6 +107,7 @@ class FusedAttnHelper: ...@@ -103,6 +107,7 @@ class FusedAttnHelper:
qkv_layout: QKVLayout qkv_layout: QKVLayout
attn_bias_type: AttnBiasType attn_bias_type: AttnBiasType
attn_mask_type: AttnMaskType attn_mask_type: AttnMaskType
softmax_type: AttnSoftmaxType
dropout_probability: float dropout_probability: float
q_num_heads: int q_num_heads: int
kv_num_heads: int kv_num_heads: int
...@@ -125,6 +130,7 @@ class FusedAttnHelper: ...@@ -125,6 +130,7 @@ class FusedAttnHelper:
self.qkv_layout.value, self.qkv_layout.value,
self.attn_bias_type.value, self.attn_bias_type.value,
self.attn_mask_type.value, self.attn_mask_type.value,
self.softmax_type.value,
self.dropout_probability, self.dropout_probability,
self.q_num_heads, self.q_num_heads,
self.kv_num_heads, self.kv_num_heads,
...@@ -254,7 +260,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -254,7 +260,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
name = "te_fused_attn_forward_ffi" name = "te_fused_attn_forward_ffi"
multiple_results = True multiple_results = True
impl_static_args = (13,) impl_static_args = (14,)
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -264,6 +270,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -264,6 +270,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k_aval, k_aval,
v_aval, v_aval,
bias_aval, bias_aval,
softmax_offset_aval,
seed_aval, seed_aval,
q_seqlen_or_cu_seqlen_aval, q_seqlen_or_cu_seqlen_aval,
kv_seqlen_or_cu_seqlen_aval, kv_seqlen_or_cu_seqlen_aval,
...@@ -312,6 +319,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -312,6 +319,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
config.qkv_layout, config.qkv_layout,
config.attn_bias_type, config.attn_bias_type,
config.attn_mask_type, config.attn_mask_type,
config.softmax_type,
config.dropout_probability, config.dropout_probability,
attn_heads, attn_heads,
num_gqa_groups, num_gqa_groups,
...@@ -375,6 +383,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -375,6 +383,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
config.dropout_probability, config.dropout_probability,
config.attn_bias_type.value, config.attn_bias_type.value,
config.attn_mask_type.value, config.attn_mask_type.value,
config.softmax_type.value,
config.qkv_layout.value, config.qkv_layout.value,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(q_aval.dtype),
config.is_training, config.is_training,
...@@ -386,6 +395,12 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -386,6 +395,12 @@ class FusedAttnFwdPrimitive(BasePrimitive):
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
) )
assert softmax_offset_aval.dtype == jnp.float32, f"Incorrect softmax_offset dtype: {softmax_offset_aval.dtype}, expected: {jnp.float32}"
if config.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
assert softmax_offset_aval.shape == (1, attn_heads, 1, 1), f"Incorrect softmax_offset shape for {config.softmax_type}: {softmax_offset_aval.shape}, expected: (1, {attn_heads}, 1, 1)"
else:
assert softmax_offset_aval.shape == (0,), f"Incorrect softmax_offset shape for {config.softmax_type}: {softmax_offset_aval.shape}, expected: (0,)"
return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval
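In short, the placeholder convention these checks enforce (shapes taken from the assertions above): a per-head fp32 offset for sink/learnable softmax, and an empty fp32 array for vanilla softmax so the primitive's arity stays fixed:

```python
import jax.numpy as jnp

attn_heads = 8
# OFF_BY_ONE_SOFTMAX / LEARNABLE_SOFTMAX: one fp32 logit per head.
softmax_offset = jnp.zeros((1, attn_heads, 1, 1), dtype=jnp.float32)
# VANILLA_SOFTMAX: empty fp32 placeholder.
vanilla_placeholder = jnp.zeros((0,), dtype=jnp.float32)
```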
@staticmethod @staticmethod
...@@ -405,6 +420,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -405,6 +420,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k, k,
v, v,
bias, bias,
softmax_offset,
seed, seed,
q_cu_seqlen, q_cu_seqlen,
kv_cu_seqlen, kv_cu_seqlen,
...@@ -453,6 +469,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -453,6 +469,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k, k,
v, v,
bias, bias,
softmax_offset,
seed, seed,
q_cu_seqlen, q_cu_seqlen,
kv_cu_seqlen, kv_cu_seqlen,
...@@ -481,6 +498,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -481,6 +498,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=window_size_left, window_size_left=window_size_left,
window_size_right=window_size_right, window_size_right=window_size_right,
softmax_type=int(config.softmax_type.value),
) )
@staticmethod @staticmethod
...@@ -489,6 +507,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -489,6 +507,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k, k,
v, v,
bias, bias,
softmax_offset,
seed, seed,
q_seqlen, q_seqlen,
kv_seqlen, kv_seqlen,
...@@ -579,6 +598,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -579,6 +598,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k, k,
v, v,
bias, bias,
softmax_offset,
seed, seed,
q_cu_seqlen, q_cu_seqlen,
kv_cu_seqlen, kv_cu_seqlen,
...@@ -596,7 +616,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -596,7 +616,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
def batcher(batched_args, batch_dims, *, config): def batcher(batched_args, batch_dims, *, config):
check_valid_batch_dims(batch_dims) check_valid_batch_dims(batch_dims)
assert FusedAttnFwdPrimitive.outer_primitive is not None assert FusedAttnFwdPrimitive.outer_primitive is not None
q_bdim, _, _, _, seed_bdim, *_ = batch_dims q_bdim, _, _, _, _, seed_bdim, *_ = batch_dims
out_bdims = q_bdim, q_bdim, seed_bdim out_bdims = q_bdim, q_bdim, seed_bdim
return ( return (
...@@ -662,7 +682,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -662,7 +682,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
mesh, PartitionSpec(get_all_mesh_axes(), None) mesh, PartitionSpec(get_all_mesh_axes(), None)
) )
arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[4] = seed_sharding arg_shardings[5] = seed_sharding
arg_shardings[-1] = arg_shardings[-3] arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4] arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings) arg_shardings = tuple(arg_shardings)
...@@ -710,7 +730,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -710,7 +730,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
name = "te_fused_attn_backward_ffi" name = "te_fused_attn_backward_ffi"
multiple_results = True multiple_results = True
impl_static_args = (16,) impl_static_args = (17,)
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -720,6 +740,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -720,6 +740,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k_aval, k_aval,
v_aval, v_aval,
bias_aval, bias_aval,
softmax_offset_aval,
softmax_aux_aval, softmax_aux_aval,
rng_state_aval, rng_state_aval,
output_aval, output_aval,
...@@ -781,6 +802,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -781,6 +802,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
config.dropout_probability, config.dropout_probability,
config.attn_bias_type.value, config.attn_bias_type.value,
config.attn_mask_type.value, config.attn_mask_type.value,
config.softmax_type.value,
config.qkv_layout.value, config.qkv_layout.value,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(q_aval.dtype),
config.is_training, config.is_training,
...@@ -798,15 +820,39 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -798,15 +820,39 @@ class FusedAttnBwdPrimitive(BasePrimitive):
shape=wkspace_shape, dtype=te_dtype_to_jax_dtype(wkspace_dtype) shape=wkspace_shape, dtype=te_dtype_to_jax_dtype(wkspace_dtype)
) )
return dq_aval, dk_aval, dv_aval, dbias_aval, wkspace_aval # Validate incoming softmax_offset shape and dtype
assert (
softmax_offset_aval.dtype == jnp.float32
), f"Incorrect softmax_offset dtype: {softmax_offset_aval.dtype}, expected: {jnp.float32}"
if config.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
assert softmax_offset_aval.shape == (1, attn_heads, 1, 1), (
f"Incorrect softmax_offset shape for {config.softmax_type}:"
f" {softmax_offset_aval.shape}, expected: (1, {attn_heads}, 1, 1)"
)
else:
assert softmax_offset_aval.shape == (0,), (
f"Incorrect softmax_offset shape for {config.softmax_type}:"
f" {softmax_offset_aval.shape}, expected: (0,)"
)
if config.softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX:
dsoftmax_offset_aval = q_aval.update(
shape=softmax_offset_aval.shape, dtype=softmax_offset_aval.dtype
)
else:
dsoftmax_offset_aval = q_aval.update(shape=(1, attn_heads, 1, 1), dtype=jnp.float32)
return dq_aval, dk_aval, dv_aval, dbias_aval, dsoftmax_offset_aval, wkspace_aval
@staticmethod @staticmethod
def outer_abstract(*args, **kwargs): def outer_abstract(*args, **kwargs):
""" """
Fused attention bwd outer primitive abstract Fused attention bwd outer primitive abstract
""" """
dq_aval, dk_aval, dv_aval, dbias_aval, _ = FusedAttnBwdPrimitive.abstract(*args, **kwargs) dq_aval, dk_aval, dv_aval, dbias_aval, dsoftmax_offset_aval, _ = (
return dq_aval, dk_aval, dv_aval, dbias_aval FusedAttnBwdPrimitive.abstract(*args, **kwargs)
)
return dq_aval, dk_aval, dv_aval, dbias_aval, dsoftmax_offset_aval
@staticmethod @staticmethod
def lowering( def lowering(
...@@ -815,6 +861,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -815,6 +861,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k, k,
v, v,
bias, bias,
softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -866,6 +913,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -866,6 +913,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k, k,
v, v,
bias, bias,
softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -897,6 +945,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -897,6 +945,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=window_size_left, window_size_left=window_size_left,
window_size_right=window_size_right, window_size_right=window_size_right,
softmax_type=int(config.softmax_type.value),
) )
@staticmethod @staticmethod
...@@ -905,6 +954,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -905,6 +954,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k, k,
v, v,
bias, bias,
softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -993,11 +1043,12 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -993,11 +1043,12 @@ class FusedAttnBwdPrimitive(BasePrimitive):
q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten()) q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten())
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten()) kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten())
dq, dk, dv, dbias, _ = FusedAttnBwdPrimitive.inner_primitive.bind( dq, dk, dv, dbias, dsoftmax_offset, _ = FusedAttnBwdPrimitive.inner_primitive.bind(
q, q,
k, k,
v, v,
bias, bias,
softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -1012,15 +1063,15 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -1012,15 +1063,15 @@ class FusedAttnBwdPrimitive(BasePrimitive):
_kv_segment_pos, _kv_segment_pos,
config=config, config=config,
) )
return dq, dk, dv, dbias return dq, dk, dv, dbias, dsoftmax_offset
@staticmethod @staticmethod
def batcher(batched_args, batch_dims, *, config): def batcher(batched_args, batch_dims, *, config):
check_valid_batch_dims(batch_dims) check_valid_batch_dims(batch_dims)
assert FusedAttnBwdPrimitive.outer_primitive is not None assert FusedAttnBwdPrimitive.outer_primitive is not None
q_bdim, k_bdim, v_bdim, *_ = batch_dims q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim, *_ = batch_dims
out_bdims = q_bdim, k_bdim, v_bdim, q_bdim out_bdims = q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim
return ( return (
FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config), FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config),
out_bdims, out_bdims,
...@@ -1033,11 +1084,13 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -1033,11 +1084,13 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k_spec = get_padded_spec(arg_infos[1]) k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2]) v_spec = get_padded_spec(arg_infos[2])
bias_spec = get_padded_spec(arg_infos[3]) bias_spec = get_padded_spec(arg_infos[3])
softmax_offset_spec = get_padded_spec(arg_infos[4])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec))
return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding, dsoftmax_offset_sharding)
@staticmethod @staticmethod
def partition(config, mesh, arg_infos, result_infos): def partition(config, mesh, arg_infos, result_infos):
...@@ -1046,21 +1099,30 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -1046,21 +1099,30 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k_spec = get_padded_spec(arg_infos[1]) k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2]) v_spec = get_padded_spec(arg_infos[2])
bias_spec = get_padded_spec(arg_infos[3]) bias_spec = get_padded_spec(arg_infos[3])
softmax_offset_spec = get_padded_spec(arg_infos[4])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec))
arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[-1] = arg_shardings[-3] arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4] arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings) arg_shardings = tuple(arg_shardings)
out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) out_shardings = (
dq_sharding,
dk_sharding,
dv_sharding,
dbias_sharding,
dsoftmax_offset_sharding,
)
def sharded_impl( def sharded_impl(
q, q,
k, k,
v, v,
bias, bias,
softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -1074,36 +1136,43 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -1074,36 +1136,43 @@ class FusedAttnBwdPrimitive(BasePrimitive):
_q_segment_pos, _q_segment_pos,
_kv_segment_pos, _kv_segment_pos,
): ):
local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl( local_dq, local_dk, local_dv, local_dbias, local_dsoftmax_offset = (
q, FusedAttnBwdPrimitive.impl(
k, q,
v, k,
bias, v,
softmax_aux, bias,
rng_state, softmax_offset,
output, softmax_aux,
doutput, rng_state,
q_cu_seqlen, output,
kv_cu_seqlen, doutput,
q_seq_offsets, q_cu_seqlen,
k_seq_offsets, kv_cu_seqlen,
_q_segment_ids, q_seq_offsets,
_kv_segment_ids, k_seq_offsets,
_q_segment_pos, _q_segment_ids,
_kv_segment_pos, _kv_segment_ids,
config=config, _q_segment_pos,
_kv_segment_pos,
config=config,
)
) )
global_dbias = local_dbias global_dbias = local_dbias
if config.attn_bias_type is not AttnBiasType.NO_BIAS: if config.attn_bias_type is not AttnBiasType.NO_BIAS:
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh) global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
return local_dq, local_dk, local_dv, global_dbias
global_dsoftmax_offset = local_dsoftmax_offset
if config.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
global_dsoftmax_offset = all_reduce_sum_along_dp_fsdp(local_dsoftmax_offset, mesh)
return local_dq, local_dk, local_dv, global_dbias, global_dsoftmax_offset
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod @staticmethod
def shardy_sharding_rule(config, mesh, value_types, result_types): def shardy_sharding_rule(config, mesh, value_types, result_types):
del config, mesh del config, mesh
# We only care about the four first arguments.
# Keep in sync with `infer_sharding_from_operands`. # Keep in sync with `infer_sharding_from_operands`.
input_spec = tuple((f"…{x}",) for x in range(len(value_types))) input_spec = tuple((f"…{x}",) for x in range(len(value_types)))
output_spec = tuple((f"…{x}",) for x in range(len(result_types))) output_spec = tuple((f"…{x}",) for x in range(len(result_types)))
...@@ -1229,6 +1298,11 @@ class _FusedAttnCPWithAllGatherHelper: ...@@ -1229,6 +1298,11 @@ class _FusedAttnCPWithAllGatherHelper:
if self.config.dropout_probability != 0.0: if self.config.dropout_probability != 0.0:
raise ValueError(f"{header} does not support dropout") raise ValueError(f"{header} does not support dropout")
if self.config.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
raise ValueError(
f"{header} only supports VANILLA_SOFTMAX, got: {self.config.softmax_type}"
)
def get_adjusted_mask(self): def get_adjusted_mask(self):
"""Converts the mask for context parallelism.""" """Converts the mask for context parallelism."""
if self.config.attn_mask_type == AttnMaskType.CAUSAL_MASK: if self.config.attn_mask_type == AttnMaskType.CAUSAL_MASK:
...@@ -1240,6 +1314,7 @@ class _FusedAttnCPWithAllGatherHelper: ...@@ -1240,6 +1314,7 @@ class _FusedAttnCPWithAllGatherHelper:
return _FusedAttnConfig( return _FusedAttnConfig(
attn_bias_type=self.config.attn_bias_type, attn_bias_type=self.config.attn_bias_type,
attn_mask_type=self.get_adjusted_mask(), attn_mask_type=self.get_adjusted_mask(),
softmax_type=self.config.softmax_type,
qkv_layout=self.config.qkv_layout, qkv_layout=self.config.qkv_layout,
scaling_factor=self.config.scaling_factor, scaling_factor=self.config.scaling_factor,
dropout_probability=self.config.dropout_probability, dropout_probability=self.config.dropout_probability,
...@@ -1376,7 +1451,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1376,7 +1451,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
mesh, PartitionSpec(get_all_mesh_axes(), None) mesh, PartitionSpec(get_all_mesh_axes(), None)
) )
arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[4] = seed_sharding arg_shardings[5] = seed_sharding
arg_shardings = tuple(arg_shardings) arg_shardings = tuple(arg_shardings)
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
...@@ -1385,6 +1460,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1385,6 +1460,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
k, k,
v, v,
bias, bias,
softmax_offset,
seed, seed,
q_seqlen, q_seqlen,
kv_seqlen, kv_seqlen,
...@@ -1404,7 +1480,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1404,7 +1480,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
# meeting the expectation of the SPMD model. # meeting the expectation of the SPMD model.
# TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding # TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding
# mask/sequence length tensor to avoid this unrolled loop. # mask/sequence length tensor to avoid this unrolled loop.
def _cross_attn(idx, q, k, v, bias, q_seqlen, kv_seqlen, seed): def _cross_attn(idx, q, k, v, bias, softmax_offset, q_seqlen, kv_seqlen, seed):
kv_max_seqlen = k.shape[1] kv_max_seqlen = k.shape[1]
kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2) kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2)
assert kv_max_seqlen % cp_size == 0, "cp size must evenly divide the kv sequence length" assert kv_max_seqlen % cp_size == 0, "cp size must evenly divide the kv sequence length"
...@@ -1431,6 +1507,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1431,6 +1507,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
k_unmasked, k_unmasked,
v_unmasked, v_unmasked,
bias, bias,
softmax_offset,
seed, seed,
q_seqlen_for_step, q_seqlen_for_step,
kv_seqlen_for_step, kv_seqlen_for_step,
...@@ -1453,7 +1530,9 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1453,7 +1530,9 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
k_ag, v_ag = helper.all_gather_kv(k, v) k_ag, v_ag = helper.all_gather_kv(k, v)
functions = [ functions = [
partial(_cross_attn, idx, q, k_ag, v_ag, bias, q_seqlen, kv_seqlen, seed) partial(
_cross_attn, idx, q, k_ag, v_ag, bias, softmax_offset, q_seqlen, kv_seqlen, seed
)
for idx in range(cp_size) for idx in range(cp_size)
] ]
...@@ -1492,18 +1571,27 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -1492,18 +1571,27 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
k_spec = get_padded_spec(arg_infos[1]) k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2]) v_spec = get_padded_spec(arg_infos[2])
bias_spec = get_padded_spec(arg_infos[3]) bias_spec = get_padded_spec(arg_infos[3])
softmax_offset_spec = get_padded_spec(arg_infos[4])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) out_shardings = (
dq_sharding,
dk_sharding,
dv_sharding,
dbias_sharding,
dsoftmax_offset_sharding,
)
def impl( def impl(
q, q,
k, k,
v, v,
bias, bias,
softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -1527,6 +1615,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -1527,6 +1615,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
k, k,
v, v,
bias, bias,
softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -1562,11 +1651,12 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -1562,11 +1651,12 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx] num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx]
kv_seqlen_for_step = (kv_seqlen // (cp_size * 2)) * num_kv_chunks kv_seqlen_for_step = (kv_seqlen // (cp_size * 2)) * num_kv_chunks
dq_local, dk_local, dv_local, dbias_local = FusedAttnBwdPrimitive.impl( dq_local, dk_local, dv_local, dbias_local, _ = FusedAttnBwdPrimitive.impl(
q_split[sub_idx], q_split[sub_idx],
k_unmasked, k_unmasked,
v_unmasked, v_unmasked,
bias, bias,
softmax_offset,
softmax_aux_split[sub_idx], softmax_aux_split[sub_idx],
rng_state, rng_state,
output_split[sub_idx], output_split[sub_idx],
...@@ -1604,6 +1694,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -1604,6 +1694,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
k_ag, k_ag,
v_ag, v_ag,
bias, bias,
softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -1621,7 +1712,9 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -1621,7 +1712,9 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
dq, dk_local, dv_local, dbias = lax.switch(cp_rank, functions) dq, dk_local, dv_local, dbias = lax.switch(cp_rank, functions)
dk, dv = helper.reduce_scatter_dkv(dk_local, dv_local) dk, dv = helper.reduce_scatter_dkv(dk_local, dv_local)
return dq, dk, dv, dbias # Return dummy dsoftmax_offset for arity matching (all-gather CP doesn't use it)
dummy_dsoftmax_offset = jnp.empty_like(softmax_offset)
return dq, dk, dv, dbias, dummy_dsoftmax_offset
return mesh, impl, out_shardings, arg_shardings return mesh, impl, out_shardings, arg_shardings
...@@ -1679,6 +1772,11 @@ class _FusedAttnCPWithP2PHelper: ...@@ -1679,6 +1772,11 @@ class _FusedAttnCPWithP2PHelper:
if self.config.dropout_probability != 0.0: if self.config.dropout_probability != 0.0:
raise ValueError(f"{header} does not support dropout") raise ValueError(f"{header} does not support dropout")
if self.config.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
raise ValueError(
f"{header} only supports VANILLA_SOFTMAX, got: {self.config.softmax_type}"
)
# We want to encourage use of a scan loop to minimize unrolling and ensure more # We want to encourage use of a scan loop to minimize unrolling and ensure more
# predictable scheduling from XLA. The unrolled flavor will be supported but is # predictable scheduling from XLA. The unrolled flavor will be supported but is
# not the preferred implementation. # not the preferred implementation.
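For intuition on this choice, a generic JAX sketch (not this helper's code): lax.scan traces the loop body once and emits a single loop op, while a Python loop bakes one copy of the body per step into the traced program:

```python
import jax.numpy as jnp
from jax import lax

def step(carry, _):
    return carry + 1.0, None

# Scan: one traced body, one loop primitive, predictable XLA scheduling.
scanned, _ = lax.scan(step, jnp.zeros(()), None, length=8)

# Unrolled: eight copies of the body appear in the traced program.
acc = jnp.zeros(())
for _ in range(8):
    acc, _ = step(acc, None)
```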
...@@ -1703,6 +1801,7 @@ class _FusedAttnCPWithP2PHelper: ...@@ -1703,6 +1801,7 @@ class _FusedAttnCPWithP2PHelper:
return _FusedAttnConfig( return _FusedAttnConfig(
attn_bias_type=self.config.attn_bias_type, attn_bias_type=self.config.attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
softmax_type=self.config.softmax_type,
qkv_layout=QKVLayout.BSHD_BS2HD, qkv_layout=QKVLayout.BSHD_BS2HD,
scaling_factor=self.config.scaling_factor, scaling_factor=self.config.scaling_factor,
dropout_probability=self.config.dropout_probability, dropout_probability=self.config.dropout_probability,
...@@ -1783,7 +1882,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1783,7 +1882,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
mesh, PartitionSpec(get_all_mesh_axes(), None) mesh, PartitionSpec(get_all_mesh_axes(), None)
) )
arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[4] = seed_sharding arg_shardings[5] = seed_sharding
# Ensure segment_pos gets same sharding as ID. # Ensure segment_pos gets same sharding as ID.
arg_shardings[-1] = arg_shardings[-3] arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4] arg_shardings[-2] = arg_shardings[-4]
...@@ -1795,6 +1894,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1795,6 +1894,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
k, k,
v, v,
bias, bias,
_softmax_offset,
seed, seed,
q_seqlen, q_seqlen,
kv_seqlen, kv_seqlen,
...@@ -1840,6 +1940,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1840,6 +1940,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
kv, kv,
_not_used, _not_used,
bias, bias,
_softmax_offset,
seed, seed,
q_seqlen_per_step, q_seqlen_per_step,
kv_seqlen_per_step, kv_seqlen_per_step,
...@@ -1865,6 +1966,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1865,6 +1966,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
kv_part, kv_part,
_not_used, _not_used,
bias, bias,
_softmax_offset,
seed, seed,
q_seqlen_per_step, q_seqlen_per_step,
kv_seqlen_per_step, kv_seqlen_per_step,
...@@ -1887,6 +1989,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1887,6 +1989,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
kv, kv,
_not_used, _not_used,
bias, bias,
_softmax_offset,
seed, seed,
q_seqlen_per_step, q_seqlen_per_step,
kv_seqlen_per_step, kv_seqlen_per_step,
...@@ -1990,18 +2093,24 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -1990,18 +2093,24 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
k_spec = get_padded_spec(arg_infos[1]) k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2]) v_spec = get_padded_spec(arg_infos[2])
bias_spec = get_padded_spec(arg_infos[3]) bias_spec = get_padded_spec(arg_infos[3])
softmax_offset_spec = get_padded_spec(arg_infos[4])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
# Ring attention doesn't use dsoftmax_offset, but we need to return it for arity matching
dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec))
arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings = [arg_i.sharding for arg_i in arg_infos]
# Ensure segment_pos gets same sharding as ID.
arg_shardings[-1] = arg_shardings[-3] arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4] arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings) arg_shardings = tuple(arg_shardings)
out_shardings = (
out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) dq_sharding,
dk_sharding,
dv_sharding,
dbias_sharding,
dsoftmax_offset_sharding,
)
helper = _FusedAttnCPWithP2PHelper(mesh, config) helper = _FusedAttnCPWithP2PHelper(mesh, config)
helper.check_supported() helper.check_supported()
...@@ -2011,6 +2120,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2011,6 +2120,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
k, k,
v, v,
bias, bias,
_softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -2054,11 +2164,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2054,11 +2164,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
def mask_compute(attn_mask_type): def mask_compute(attn_mask_type):
q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)
dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl( dq_per_step, dk_dv_per_step, _, dbias_per_step, _ = FusedAttnBwdPrimitive.impl(
q, q,
kv, kv,
_not_used, _not_used,
bias, bias,
_softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -2082,11 +2193,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2082,11 +2193,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) // 2 kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) // 2
kv_part = lax.slice_in_dim(kv, 0, kv_max_seqlen // 2, axis=1) kv_part = lax.slice_in_dim(kv, 0, kv_max_seqlen // 2, axis=1)
dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl( dq_per_step, dk_dv_per_step, _, dbias_per_step, _ = FusedAttnBwdPrimitive.impl(
q, q,
kv_part, kv_part,
_not_used, _not_used,
bias, bias,
_softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -2120,11 +2232,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2120,11 +2232,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
softmax_aux, q_max_seqlen // 2, q_max_seqlen, axis=2 softmax_aux, q_max_seqlen // 2, q_max_seqlen, axis=2
) )
dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl( dq_per_step, dk_dv_per_step, _, dbias_per_step, _ = FusedAttnBwdPrimitive.impl(
q_part, q_part,
kv, kv,
_not_used, _not_used,
bias, bias,
_softmax_offset,
softmax_aux_part, softmax_aux_part,
rng_state, rng_state,
output_part, output_part,
...@@ -2184,7 +2297,9 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2184,7 +2297,9 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh) global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)
dk, dv = helper.unstack_kv(dk_dv) dk, dv = helper.unstack_kv(dk_dv)
return dq, dk, dv, global_dbias # Return dummy dsoftmax_offset for arity matching (ring attention doesn't use it)
dummy_dsoftmax_offset = jnp.empty_like(_softmax_offset)
return dq, dk, dv, global_dbias, dummy_dsoftmax_offset
return mesh, ring_attn_bwd_impl, out_shardings, arg_shardings return mesh, ring_attn_bwd_impl, out_shardings, arg_shardings
...@@ -2273,7 +2388,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -2273,7 +2388,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
mesh, PartitionSpec(get_all_mesh_axes(), None) mesh, PartitionSpec(get_all_mesh_axes(), None)
) )
arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[4] = seed_sharding arg_shardings[5] = seed_sharding
# Ensure segment_pos gets same sharding as ID. # Ensure segment_pos gets same sharding as ID.
arg_shardings[-1] = arg_shardings[-3] arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4] arg_shardings[-2] = arg_shardings[-4]
...@@ -2285,6 +2400,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -2285,6 +2400,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
k, k,
v, v,
bias, bias,
_softmax_offset,
seed, seed,
q_seqlen, q_seqlen,
kv_seqlen, kv_seqlen,
...@@ -2336,6 +2452,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -2336,6 +2452,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
kv, kv,
_not_used, _not_used,
bias, bias,
_softmax_offset,
seed, seed,
q_seqlen, q_seqlen,
kv_seqlen, kv_seqlen,
...@@ -2345,7 +2462,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -2345,7 +2462,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
kv_segment_ids, kv_segment_ids,
q_segment_pos, q_segment_pos,
kv_segment_pos, kv_segment_pos,
config, config=config,
) )
if config.window_size != (-1, -1): if config.window_size != (-1, -1):
...@@ -2420,8 +2537,8 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2420,8 +2537,8 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
arg_shardings[-1] = arg_shardings[-3] arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4] arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings) arg_shardings = tuple(arg_shardings)
# dq, dk, dv, dbias sharding = q, k, v, bias sharding # dq, dk, dv, dbias, dsoftmax_offset sharding = q, k, v, bias, softmax_offset sharding
out_shardings = tuple(arg.sharding for arg in arg_infos[:4]) out_shardings = tuple(arg.sharding for arg in arg_infos[:5])
helper = _FusedAttnCPWithP2PHelper(mesh, config) helper = _FusedAttnCPWithP2PHelper(mesh, config)
helper.check_supported() helper.check_supported()
...@@ -2431,6 +2548,7 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2431,6 +2548,7 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
k, k,
v, v,
bias, bias,
_softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -2478,11 +2596,12 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2478,11 +2596,12 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm) kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm)
def compute(config): def compute(config):
dq_per_step, dkv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl( dq_per_step, dkv_per_step, _, dbias_per_step, _ = FusedAttnBwdPrimitive.impl(
q, q,
kv, kv,
_not_used, _not_used,
bias, bias,
_softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -2536,7 +2655,9 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2536,7 +2655,9 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh) global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)
dk, dv = helper.unstack_kv(dkv) dk, dv = helper.unstack_kv(dkv)
return dq, dk, dv, global_dbias # Return dummy dsoftmax_offset for arity matching (ring attention doesn't use it)
dummy_dsoftmax_offset = jnp.empty_like(_softmax_offset)
return dq, dk, dv, global_dbias, dummy_dsoftmax_offset
return mesh, bwd_impl, out_shardings, arg_shardings return mesh, bwd_impl, out_shardings, arg_shardings
...@@ -2557,10 +2678,12 @@ def _maybe_context_parallel_axis(cp_axis: str): ...@@ -2557,10 +2678,12 @@ def _maybe_context_parallel_axis(cp_axis: str):
def fused_attn_fwd( def fused_attn_fwd(
qkv: Tuple[jnp.ndarray, ...], qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray], bias: Optional[jnp.ndarray],
softmax_offset: Optional[jnp.ndarray],
sequence_descriptor: SequenceDescriptor, sequence_descriptor: SequenceDescriptor,
seed: Optional[jnp.ndarray], seed: Optional[jnp.ndarray],
attn_bias_type: AttnBiasType, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, attn_mask_type: AttnMaskType,
softmax_type: AttnSoftmaxType,
qkv_layout: QKVLayout, qkv_layout: QKVLayout,
scaling_factor: float, scaling_factor: float,
dropout_probability: float, dropout_probability: float,
...@@ -2585,6 +2708,7 @@ def fused_attn_fwd( ...@@ -2585,6 +2708,7 @@ def fused_attn_fwd(
query has a different shape (e.g., cross-attention). query has a different shape (e.g., cross-attention).
- `(query, key, value)`: For separate query, key, and value tensors. - `(query, key, value)`: For separate query, key, and value tensors.
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores. bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
softmax_offset (Optional[jnp.ndarray]): An optional softmax offset (attention sink) tensor of shape [1, num_heads, 1, 1]; pass None for vanilla softmax.
q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,]. q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,].
kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,]. kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,].
q_seq_offsets (jnp.ndarray): q_seq_offsets (jnp.ndarray):
...@@ -2594,6 +2718,7 @@ def fused_attn_fwd( ...@@ -2594,6 +2718,7 @@ def fused_attn_fwd(
seed (Optional[jnp.ndarray]): Optional random seed for dropout. seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (AttnBiasType): Type of attention bias. attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask. attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of softmax.
qkv_layout (QKVLayout): Layout of the QKV tensors. qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores. scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention. dropout_probability (float): Dropout probability to apply during attention.
...@@ -2633,10 +2758,36 @@ def fused_attn_fwd( ...@@ -2633,10 +2758,36 @@ def fused_attn_fwd(
assert bias is None assert bias is None
bias = jnp.zeros(0, dtype=qkv[0].dtype) bias = jnp.zeros(0, dtype=qkv[0].dtype)
if softmax_offset is None:
assert (
softmax_type != AttnSoftmaxType.LEARNABLE_SOFTMAX
), f"Softmax type {softmax_type} is not supported when softmax_offset is None"
if softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
num_heads = qkv[0].shape[-2]
# Create a tensor of shape [1, h, 1, 1] filled with zeros (sink logit value = 0).
# This adds exp(0 - x_max) = exp(-x_max) to the max-shifted denominator,
# which is equivalent to adding 1 to the unshifted sum, giving: exp(x_i) / (sum(exp(x_j)) + 1)
softmax_offset = jnp.zeros((1, num_heads, 1, 1), dtype=jnp.float32)
# Shard by heads dimension
softmax_offset = with_sharding_constraint_by_logical_axes(
softmax_offset, (None, HEAD_AXES, None, None)
)
else:
assert softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX
softmax_offset = jnp.zeros(0, dtype=jnp.float32)
else:
assert softmax_offset.dtype == jnp.float32
# Shard by heads dimension if not VANILLA_SOFTMAX
if softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
softmax_offset = with_sharding_constraint_by_logical_axes(
softmax_offset, (None, HEAD_AXES, None, None)
)
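A pure-JAX reference of the off-by-one softmax that this zero offset realizes — a sketch for intuition following the comment's math, not the cuDNN path:

```python
import jax.numpy as jnp

def off_by_one_softmax(logits, axis=-1):
    # Include the implicit sink logit 0 in the max for numerical stability.
    m = jnp.maximum(jnp.max(logits, axis=axis, keepdims=True), 0.0)
    num = jnp.exp(logits - m)
    denom = jnp.sum(num, axis=axis, keepdims=True) + jnp.exp(-m)
    return num / denom  # == exp(x_i) / (sum_j exp(x_j) + 1)

x = jnp.array([0.5, 1.5, -2.0])
assert jnp.allclose(
    off_by_one_softmax(x), jnp.exp(x) / (jnp.sum(jnp.exp(x)) + 1.0)
)
```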
fused_config = _FusedAttnConfig( fused_config = _FusedAttnConfig(
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
softmax_type=softmax_type,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training, is_training=is_training,
...@@ -2662,6 +2813,7 @@ def fused_attn_fwd( ...@@ -2662,6 +2813,7 @@ def fused_attn_fwd(
output, softmax_aux, rng_state = primitive.bind( output, softmax_aux, rng_state = primitive.bind(
*qkv_for_primitive, *qkv_for_primitive,
bias, bias,
softmax_offset,
seed, seed,
*seq_desc_flatten, *seq_desc_flatten,
config=fused_config, config=fused_config,
...@@ -2673,6 +2825,7 @@ def fused_attn_fwd( ...@@ -2673,6 +2825,7 @@ def fused_attn_fwd(
def fused_attn_bwd( def fused_attn_bwd(
qkv: Tuple[jnp.ndarray, ...], qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray], bias: Optional[jnp.ndarray],
softmax_offset: Optional[jnp.ndarray],
softmax_aux: jnp.ndarray, softmax_aux: jnp.ndarray,
rng_state: jnp.ndarray, rng_state: jnp.ndarray,
output: jnp.ndarray, output: jnp.ndarray,
...@@ -2681,6 +2834,7 @@ def fused_attn_bwd( ...@@ -2681,6 +2834,7 @@ def fused_attn_bwd(
attn_bias_type: AttnBiasType, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, attn_mask_type: AttnMaskType,
qkv_layout: QKVLayout, qkv_layout: QKVLayout,
softmax_type: AttnSoftmaxType,
scaling_factor: float, scaling_factor: float,
dropout_probability: float, dropout_probability: float,
is_training: bool, is_training: bool,
...@@ -2702,6 +2856,7 @@ def fused_attn_bwd( ...@@ -2702,6 +2856,7 @@ def fused_attn_bwd(
query has a different shape (e.g., cross-attention). query has a different shape (e.g., cross-attention).
- `(query, key, value)`: For separate query, key, and value tensors. - `(query, key, value)`: For separate query, key, and value tensors.
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores. bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
softmax_offset (Optional[jnp.ndarray]): An optional softmax offset (attention sink) tensor of shape [1, num_heads, 1, 1]; pass None for vanilla softmax.
softmax_aux (jnp.ndarray): Auxiliary tensors from the softmax step used in the forward pass. softmax_aux (jnp.ndarray): Auxiliary tensors from the softmax step used in the forward pass.
rng_state (jnp.ndarray): Auxiliary tensors to save the random state in the forward pass. rng_state (jnp.ndarray): Auxiliary tensors to save the random state in the forward pass.
output (jnp.ndarray): The output tensor from the forward pass. output (jnp.ndarray): The output tensor from the forward pass.
...@@ -2714,6 +2869,7 @@ def fused_attn_bwd( ...@@ -2714,6 +2869,7 @@ def fused_attn_bwd(
The offsets in the sequence dim for the query, with shape [batch + 1,]. The offsets in the sequence dim for the query, with shape [batch + 1,].
attn_bias_type (AttnBiasType): Type of attention bias. attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask. attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of softmax.
qkv_layout (QKVLayout): Layout of the QKV tensors. qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores. scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention. dropout_probability (float): Dropout probability to apply during attention.
...@@ -2755,6 +2911,28 @@ def fused_attn_bwd( ...@@ -2755,6 +2911,28 @@ def fused_attn_bwd(
assert bias is None assert bias is None
bias = jnp.zeros(0, dtype=qkv[0].dtype) bias = jnp.zeros(0, dtype=qkv[0].dtype)
if softmax_offset is None:
assert softmax_type != AttnSoftmaxType.LEARNABLE_SOFTMAX, f"softmax_offset is required when {softmax_type=}"
if softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
num_heads = qkv[0].shape[-2]
# Create tensor [1, h, 1, 1] filled with zeros
softmax_offset = jnp.zeros((1, num_heads, 1, 1), dtype=jnp.float32)
# Shard by heads dimension
softmax_offset = with_sharding_constraint_by_logical_axes(
softmax_offset, (None, HEAD_AXES, None, None)
)
elif softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX:
softmax_offset = jnp.zeros(0, dtype=jnp.float32)
else:
raise NotImplementedError(f"Unknown {softmax_type=}")
else:
softmax_offset = softmax_offset.astype(jnp.float32)
# Shard by heads dimension if not VANILLA_SOFTMAX
if softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
softmax_offset = with_sharding_constraint_by_logical_axes(
softmax_offset, (None, HEAD_AXES, None, None)
)
    # TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on
    # sm100+
    compute_capabilities = get_all_device_compute_capability()
...@@ -2767,6 +2945,7 @@ def fused_attn_bwd(
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
        softmax_type=softmax_type,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
...@@ -2788,9 +2967,10 @@ def fused_attn_bwd(
        primitive = FusedRingAttnBwdPrimitive.outer_primitive
    seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
    *qkv_grads, bias_grad, softmax_offset_grad = primitive.bind(
        *qkv_for_primitive,
        bias,
        softmax_offset,
        softmax_aux,
        rng_state,
        output,
...@@ -2798,4 +2978,4 @@ def fused_attn_bwd(
        *seq_desc_flatten,
        config=fused_config,
    )
    return tuple(qkv_grads[: len(qkv)]), bias_grad, softmax_offset_grad
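
The third return value is the gradient of the learnable sink parameter. For reference, a self-contained pure-JAX sketch of the quantity it corresponds to (hypothetical names and shapes, not the fused kernel itself):

import jax
import jax.numpy as jnp

def sink_attention(q, k, v, alpha, scale):
    # q, k, v: [s, h, d]; alpha: [h] learnable per-head sink logits
    logits = scale * jnp.einsum("qhd,khd->hqk", q, k)       # [h, s_q, s_kv]
    off = alpha[:, None, None]                              # [h, 1, 1]
    m = jnp.maximum(logits.max(-1, keepdims=True), off)     # stable max incl. the sink
    p = jnp.exp(logits - m)
    denom = p.sum(-1, keepdims=True) + jnp.exp(off - m)     # the sink joins the denominator
    return jnp.einsum("hqk,khd->qhd", p / denom, v)

q = k = v = 0.1 * jnp.ones((4, 2, 8))
alpha = jnp.zeros((2,))
# d(output)/d(alpha): the same role softmax_offset_grad plays above
dalpha = jax.grad(lambda a: sink_attention(q, k, v, a, 0.25).sum())(alpha)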
...@@ -11,10 +11,11 @@ import jax
import jax.numpy as jnp
from jax import dtypes, ffi
from jax.sharding import PartitionSpec, NamedSharding

from .attention import AttnSoftmaxType
from .base import BasePrimitive, register_primitive
from .misc import get_padded_spec, check_valid_batch_dims
from ..softmax import SoftmaxFusionType

__all__ = [
...@@ -32,7 +33,8 @@ __all__ = [
def is_softmax_kernel_available(
    softmax_fusion_type: SoftmaxFusionType,
    softmax_type: AttnSoftmaxType,
    batch: int,
    heads: int,
    q_seqlen: int,
...@@ -40,15 +42,18 @@ def is_softmax_kernel_available(
    dtype: jnp.dtype,
):
"""check softmax available""" """check softmax available"""
if softmax_type is SoftmaxType.SCALED: if softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
return False
if softmax_fusion_type is SoftmaxFusionType.SCALED:
return ScaledSoftmaxFwdPrimitive.is_kernel_available( return ScaledSoftmaxFwdPrimitive.is_kernel_available(
batch, heads, q_seqlen, k_seqlen, dtype batch, heads, q_seqlen, k_seqlen, dtype
) )
if softmax_type is SoftmaxType.SCALED_MASKED: if softmax_fusion_type is SoftmaxFusionType.SCALED_MASKED:
return ScaledMaskedSoftmaxFwdPrimitive.is_kernel_available( return ScaledMaskedSoftmaxFwdPrimitive.is_kernel_available(
batch, heads, q_seqlen, k_seqlen, dtype batch, heads, q_seqlen, k_seqlen, dtype
) )
if softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED: if softmax_fusion_type is SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED:
return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available( return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
batch, heads, q_seqlen, k_seqlen, dtype batch, heads, q_seqlen, k_seqlen, dtype
) )
...@@ -792,26 +797,77 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)


def jax_scaled_softmax(
    logits: jnp.ndarray, scale_factor: float, softmax_offset: jnp.ndarray | float | None = None
):
    """
    JAX based implementation of scaled softmax
    """
    if softmax_offset is not None:
        return jax_general_softmax(scale_factor * logits, offset=softmax_offset)
    return jax.nn.softmax(scale_factor * logits)


def jax_scaled_masked_softmax(
    logits: jnp.ndarray,
    mask: jnp.ndarray,
    scale_factor: float,
    softmax_offset: jnp.ndarray | float | None = None,
):
    """
    JAX based implementation of scaled and masked softmax
    """
    if softmax_offset is not None:
        return jax_general_softmax(logits * scale_factor, offset=softmax_offset, where=mask != 1)
    return jax.nn.softmax(logits * scale_factor, where=mask != 1)


def jax_scaled_upper_triang_masked_softmax(
    logits: jnp.ndarray, scale_factor: float, softmax_offset: jnp.ndarray | float | None = None
):
    """
    JAX based implementation of scaled and upper triangle masked softmax
    """
    mask = 1 - jnp.tril(jnp.ones_like(logits))
    return jax_scaled_masked_softmax(logits, mask, scale_factor, softmax_offset)
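
An illustrative call (hypothetical shapes): passing softmax_offset=0.0 through the causal helper yields off-by-one softmax, so each row's probabilities sum to strictly less than one.

import jax.numpy as jnp

logits = jnp.arange(16.0).reshape(1, 1, 4, 4)  # [b, h, s_q, s_kv]
probs = jax_scaled_upper_triang_masked_softmax(logits, 1.0, softmax_offset=0.0)
row_sums = probs.sum(-1)  # each equals S / (1 + S), with S = sum(exp(logits)) over the row
assert bool((row_sums < 1.0).all())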
def jax_general_softmax(
    x: jnp.ndarray,
    axis: int = -1,
    where: jnp.ndarray | None = None,
    initial: float = -jnp.inf,
    offset: jnp.ndarray | float | None = None,
) -> jnp.ndarray:
    """
    JAX based implementation of general softmax with optional masking and offset.
    """
    # Compute max of x
    x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
    if offset is not None:
        # Cast offset to x.dtype to prevent type promotion
        if isinstance(offset, (int, float)):
            offset = jnp.array(offset, dtype=x.dtype)
        else:
            offset = offset.astype(x.dtype)
        # Include offset in max: x_max = max(x_max, offset)
        # This is equivalent to computing max over [x..., offset]
        x_max = jnp.maximum(x_max, offset)
    x_safe = x if where is None else jnp.where(where, x, initial)
    unnormalized = jnp.exp(x_safe - x_max)
    denominator = jnp.sum(unnormalized, axis, where=where, keepdims=True)
    if offset is not None:
        # Add exp(offset - x_max) to the denominator: the "sink" term
        denominator = denominator + jnp.exp(offset - x_max)
    result = unnormalized / denominator
    if where is not None:
        result = jnp.where(where, result, 0)
    return result
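
A quick equivalence check (a hedged sketch, not part of this PR's test suite): appending the offset as one extra "sink" column and applying a plain softmax must reproduce the same probabilities for the real columns.

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 4, 5))
off = jnp.zeros(())  # scalar sink logit, i.e. off-by-one softmax
sink_col = jnp.broadcast_to(off, x.shape[:-1] + (1,))
ref = jax.nn.softmax(jnp.concatenate([x, sink_col], axis=-1))
assert jnp.allclose(jax_general_softmax(x, offset=off), ref[..., :-1], atol=1e-6)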
def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
...
...@@ -108,28 +108,28 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);

NVTE_Fused_Attn_Backend GetFusedAttnBackend(
    bool is_training, DType q_dtype, DType kv_dtype, NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
    float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen,
    size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left,
    int64_t window_size_right);

pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
    size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
    DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left,
    int64_t window_size_right);

pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
    size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
    DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq,
    int64_t window_size_left, int64_t window_size_right);

// GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler);
...
...@@ -11,14 +11,12 @@
namespace transformer_engine {
namespace jax {

NVTE_Fused_Attn_Backend GetFusedAttnBackend(
    bool is_training, DType q_dtype, DType kv_dtype, NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
    float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen,
    size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left,
    int64_t window_size_right) {
  auto backend = nvte_get_fused_attn_backend(
      is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
      bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads,
...@@ -39,7 +37,8 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t
                                       const size_t kv_max_seqlen, DType dtype,
                                       NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend,
                                       void *softmax_buf, void *rng_state_buf = nullptr,
                                       void *bias_buf = nullptr,
                                       void *softmax_offset_buf = nullptr) {
  // all backends need softmax but expect different shapes/dtypes
  // start with the max512 sequence length softmax shape/dtype and correct later
  tensor_pack->size = 1;
...@@ -67,10 +66,12 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t
    softmax_aux_data.shape.data[3] = 1;  // {B,H,Qs,Ks} -> {B,H,Qs,1}
    softmax_aux_data.dtype = static_cast<NVTEDType>(DType::kFloat32);

    int size = 2;  // Start at 2 (we have softmax and rng_state at indices 0, 1)
    // include bias if enabled
    if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS && bias_type != NVTE_Bias_Type::NVTE_ALIBI) {
      NVTETensor &bias_aux = tensor_pack->tensors[size];
      size++;
      NVTEBasicTensor bias_aux_data;
      bias_aux_data.data_ptr = bias_buf;
      bias_aux_data.shape.ndim = 4;
...@@ -81,6 +82,24 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t
      bias_aux_data.dtype = static_cast<NVTEDType>(dtype);
      nvte_set_tensor_param(&bias_aux, kNVTERowwiseData, &bias_aux_data);
    }

    // include softmax_offset if provided
    if (softmax_offset_buf != nullptr) {
      NVTETensor &softmax_offset_aux = tensor_pack->tensors[size];
      size++;
      NVTEBasicTensor softmax_offset_aux_data;
      softmax_offset_aux_data.data_ptr = softmax_offset_buf;
      softmax_offset_aux_data.shape.ndim = 4;
      softmax_offset_aux_data.shape.data[0] = 1;
      softmax_offset_aux_data.shape.data[1] = attn_heads;
      softmax_offset_aux_data.shape.data[2] = 1;
      softmax_offset_aux_data.shape.data[3] = 1;
      softmax_offset_aux_data.dtype = static_cast<NVTEDType>(DType::kFloat32);
      nvte_set_tensor_param(&softmax_offset_aux, kNVTERowwiseData, &softmax_offset_aux_data);
    }

    // Set final size
    tensor_pack->size = size;
  }
  nvte_set_tensor_param(&softmax_aux, kNVTERowwiseData, &softmax_aux_data);
}
...@@ -98,14 +117,16 @@ void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_
                                        const size_t bias_heads, const size_t q_max_seqlen,
                                        const size_t kv_max_seqlen, DType dtype,
                                        NVTE_Fused_Attn_Backend backend, void *softmax_buf,
                                        void *rng_state_buf, void *bias_buf,
                                        void *softmax_offset_buf = nullptr) {
  // Backward calls put everything into the tensor pack for every backend
  // so we set dummy bias_type and backend choices here to follow the correct code path
  auto dummy_bias_type = NVTE_Bias_Type::NVTE_POST_SCALE_BIAS;
  auto dummy_backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
  PrepareFusedAttnForwardAuxTensors(tensor_pack, input_batch, bias_batch, attn_heads, bias_heads,
                                    q_max_seqlen, kv_max_seqlen, dtype, dummy_bias_type,
                                    dummy_backend, softmax_buf, rng_state_buf, bias_buf,
                                    softmax_offset_buf);

  // correct softmax shape for max512 sequence length kernel
  if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
...@@ -121,8 +142,9 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
    size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
    DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left,
    int64_t window_size_right) {
  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
  auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
  auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
...@@ -141,7 +163,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
  auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
  auto dummy_softmax_offset_tensor =
      TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kFloat32);

  NVTETensorPack aux_output_tensors;
  nvte_tensor_pack_create(&aux_output_tensors);
...@@ -208,18 +229,21 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
  auto layout_group = nvte_get_qkv_layout_group(qkv_layout);

static void FusedAttnForwardImpl(
    cudaStream_t stream, void *q, void *k, void *v, void *bias, void *softmax_offset, void *seed,
    void *q_cu_seqlens, void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *output,
    void *softmax_aux, void *rng_state, void *workspace, size_t input_batch, size_t bias_batch,
    size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups,
    size_t bias_heads, size_t qk_head_dim, size_t v_head_dim, size_t max_segments_per_seq,
    size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
    DType dtype, DType wkspace_dtype, bool is_training, bool deterministic,
    int64_t window_size_left, int64_t window_size_right) {
  FUSED_ATTN_IMPL_COMMON_BLOCK;

  /* Input tensors */
  auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
  auto softmax_offset_tensor =
      TensorWrapper(softmax_offset, std::vector<size_t>{1, attn_heads, 1, 1}, DType::kFloat32);

  if (is_ragged) {
    auto output_size = input_batch * q_max_seqlen * attn_heads * v_head_dim;
...@@ -238,10 +262,6 @@ static void FusedAttnForwardImpl(
  /* Prepare RNG state */
  auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);

  auto backend = nvte_get_fused_attn_backend(
      is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
      bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
...@@ -254,7 +274,7 @@ static void FusedAttnForwardImpl(
  nvte_tensor_pack_create(&aux_output_tensors);
  PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, input_batch, bias_batch, attn_heads,
                                    bias_heads, q_max_seqlen, kv_max_seqlen, dtype, bias_type,
                                    backend, softmax_aux, softmax_offset);

  /* Call the underlying NVTE API */
  auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
...@@ -303,7 +323,7 @@ static void FusedAttnForwardImpl(
  nvte_fused_attn_fwd(
      q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
      softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
      q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
      k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
      rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
...@@ -332,6 +352,8 @@ static void FusedAttnForwardImpl(
      static_cast<NVTE_Bias_Type>(get_attr_value<int64_t>(attrs, "bias_type"));       \
  NVTE_Mask_Type mask_type =                                                          \
      static_cast<NVTE_Mask_Type>(get_attr_value<int64_t>(attrs, "mask_type"));       \
  NVTE_Softmax_Type softmax_type =                                                    \
      static_cast<NVTE_Softmax_Type>(get_attr_value<int64_t>(attrs, "softmax_type")); \
  NVTE_QKV_Layout qkv_layout =                                                        \
      static_cast<NVTE_QKV_Layout>(get_attr_value<int64_t>(attrs, "qkv_layout"));     \
  bool is_training = get_attr_value<bool>(attrs, "is_training");                      \
...@@ -342,7 +364,8 @@ static void FusedAttnForwardImpl(
  DType wkspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());

Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf,
                               Buffer_Type v_buf, Buffer_Type bias_buf,
                               Buffer_Type softmax_offset_buf, Buffer_Type seed_buf,
                               Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf,
                               Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf,
                               Variadic_Buffer_Type _unused_args, Result_Type output_buf,
...@@ -352,15 +375,15 @@ Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Ty
  FusedAttnForwardImpl(
      stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(),
      bias_buf.untyped_data(), softmax_offset_buf.untyped_data(), seed_buf.untyped_data(),
      q_cu_seqlens_buf.untyped_data(), kv_cu_seqlens_buf.untyped_data(),
      is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr,
      is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, output_buf->untyped_data(),
      softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(), workspace_buf->untyped_data(),
      input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads,
      qk_head_dim, v_head_dim, max_segments_per_seq, wkspace_size, scaling_factor,
      dropout_probability, bias_type, mask_type, softmax_type, qkv_layout, dtype, wkspace_dtype,
      is_training, deterministic, window_size_left, window_size_right);
  return ffi_with_cuda_error_check();
}
...@@ -371,6 +394,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardHandler, FusedAttnForwardFFI,
                                  .Arg<Buffer_Type>()      // k
                                  .Arg<Buffer_Type>()      // v
                                  .Arg<Buffer_Type>()      // bias
                                  .Arg<Buffer_Type>()      // softmax_offset
                                  .Arg<Buffer_Type>()      // seed_buf
                                  .Arg<Buffer_Type>()      // q_cu_seqlens
                                  .Arg<Buffer_Type>()      // kv_cu_seqlens
...@@ -388,9 +412,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
    size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
    DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq,
    int64_t window_size_left, int64_t window_size_right) {
  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
  auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
  auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
...@@ -425,9 +449,14 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    // For cuDNN < 9.3.0, it is required to run all possible seqlens to address act_seqlen = 0
    min_num_segments = input_batch * max_segments_per_seq;
  }

  TensorWrapper dummy_d_softmax_offset_tensor;
  if (softmax_type == NVTE_Softmax_Type::NVTE_OFF_BY_ONE_SOFTMAX ||
      softmax_type == NVTE_Softmax_Type::NVTE_LEARNABLE_SOFTMAX) {
    dummy_d_softmax_offset_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{1, attn_heads, 1, 1}, DType::kFloat32);
  }

  for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) {
    // the last one is the largest which will be the returned workspace size
    auto q_cu_seqlens_tensor =
...@@ -457,15 +486,16 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
}
static void FusedAttnBackwardImpl(
    cudaStream_t stream, void *q, void *k, void *v, void *bias, void *softmax_offset,
    void *softmax_aux, void *rng_state, void *output, void *doutput, void *q_cu_seqlens,
    void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *dq, void *dk, void *dv,
    void *dbias, void *dsoftmax_offset, void *workspace, size_t input_batch, size_t bias_batch,
    size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups,
    size_t bias_heads, size_t qk_head_dim, size_t v_head_dim, size_t max_segments_per_seq,
    size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
    DType dtype, DType wkspace_dtype, bool is_training, bool deterministic,
    int64_t window_size_left, int64_t window_size_right) {
  FUSED_ATTN_IMPL_COMMON_BLOCK;

  /* Input tensors */
...@@ -476,9 +506,13 @@ static void FusedAttnBackwardImpl(
  /* Output tensors */
  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in F16
  auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);

  TensorWrapper dsoftmax_offset_tensor;
  if (softmax_type == NVTE_Softmax_Type::NVTE_OFF_BY_ONE_SOFTMAX ||
      softmax_type == NVTE_Softmax_Type::NVTE_LEARNABLE_SOFTMAX) {
    dsoftmax_offset_tensor =
        TensorWrapper(dsoftmax_offset, std::vector<size_t>{1, attn_heads, 1, 1}, DType::kFloat32);
  }

  /* Auxiliary tensors (propagated from the forward pass) */
  NVTETensorPack aux_input_tensors;
...@@ -490,7 +524,7 @@ static void FusedAttnBackwardImpl(
                                      false, false);
  PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
                                     bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
                                     softmax_aux, rng_state, bias, softmax_offset);

  /* Call the underlying NVTE API */
  // Prepare Q, K, V pointers and shapes based on layout
...@@ -564,7 +598,7 @@ static void FusedAttnBackwardImpl(
      s_tensor.data(),  // not used for F16
      s_tensor.data(),  // not used for F16
      &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), dbias_tensor.data(),
      dsoftmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
      q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen,
      scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
      window_size_left, window_size_right, deterministic, false, workspace_tensor.data(), stream);
...@@ -574,26 +608,29 @@ static void FusedAttnBackwardImpl(
Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf,
                                Buffer_Type v_buf, Buffer_Type bias_buf,
                                Buffer_Type softmax_offset_buf, Buffer_Type softmax_aux_buf,
                                Buffer_Type rng_state_buf, Buffer_Type output_buf,
                                Buffer_Type doutput_buf, Buffer_Type q_cu_seqlens_buf,
                                Buffer_Type kv_cu_seqlens_buf, Buffer_Type q_seq_offsets_buf,
                                Buffer_Type k_seq_offsets_buf, Variadic_Buffer_Type _unused_args,
                                Result_Type dq_buf, Result_Type dk_buf, Result_Type dv_buf,
                                Result_Type dbias_buf, Result_Type dsoftmax_offset_buf,
                                Result_Type workspace_buf, Dictionary attrs) {
  FUSED_ATTN_FFI_GET_ATTRS;

  FusedAttnBackwardImpl(
      stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(),
      bias_buf.untyped_data(), softmax_offset_buf.untyped_data(), softmax_aux_buf.untyped_data(),
      rng_state_buf.untyped_data(), output_buf.untyped_data(), doutput_buf.untyped_data(),
      q_cu_seqlens_buf.untyped_data(), kv_cu_seqlens_buf.untyped_data(),
      is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr,
      is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, dq_buf->untyped_data(),
      dk_buf->untyped_data(), dv_buf->untyped_data(), dbias_buf->untyped_data(),
      dsoftmax_offset_buf->untyped_data(), workspace_buf->untyped_data(), input_batch, bias_batch,
      q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, qk_head_dim, v_head_dim,
      max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type,
      softmax_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic, window_size_left,
      window_size_right);
  return ffi_with_cuda_error_check();
}
...@@ -605,6 +642,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI,
                                  .Arg<Buffer_Type>()      // k
                                  .Arg<Buffer_Type>()      // v
                                  .Arg<Buffer_Type>()      // bias
                                  .Arg<Buffer_Type>()      // softmax_offset
                                  .Arg<Buffer_Type>()      // softmax_aux
                                  .Arg<Buffer_Type>()      // rng_state
                                  .Arg<Buffer_Type>()      // output
...@@ -618,6 +656,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI,
                                  .Ret<Buffer_Type>()      // dk
                                  .Ret<Buffer_Type>()      // dv
                                  .Ret<Buffer_Type>()      // dbias
                                  .Ret<Buffer_Type>()      // dsoftmax_offset
                                  .Ret<Buffer_Type>()      // workspace
                                  .Attrs(),
                              FFI_CudaGraph_Traits);
...
...@@ -142,6 +142,11 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
      .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD)
      .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD);

  pybind11::enum_<NVTE_Softmax_Type>(m, "NVTE_Softmax_Type", pybind11::module_local())
      .value("NVTE_VANILLA_SOFTMAX", NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)
      .value("NVTE_OFF_BY_ONE_SOFTMAX", NVTE_Softmax_Type::NVTE_OFF_BY_ONE_SOFTMAX)
      .value("NVTE_LEARNABLE_SOFTMAX", NVTE_Softmax_Type::NVTE_LEARNABLE_SOFTMAX);

  pybind11::enum_<NVTE_Activation_Type>(m, "NVTE_Activation_Type", pybind11::module_local())
      .value("GELU", NVTE_Activation_Type::GELU)
      .value("GEGLU", NVTE_Activation_Type::GEGLU)
...
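
With the binding above, the softmax variants become visible on the Python side. A hedged sketch, assuming the extension builds and is importable under its PYBIND11_MODULE name:

import transformer_engine_jax as tex

# pybind11 enum members expose a name and convert to their integer values
for v in (
    tex.NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX,
    tex.NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX,
    tex.NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX,
):
    print(v.name, int(v))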
...@@ -7,6 +7,7 @@ Wrapper module for Transformer related layers with FP8 support.
from functools import reduce
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType, Optional
import warnings

import numpy as np
import jax.numpy as jnp
...@@ -23,8 +24,9 @@ from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense
from ..layernorm_mlp import layernorm_mlp
from ..activation import activation
from ..softmax import softmax, SoftmaxFusionType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..attention import AttnSoftmaxType
from ..cpp_extensions import (
    is_softmax_kernel_available,
    jax_scaled_softmax,
...@@ -171,15 +173,20 @@ class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
    softmax_fusion_type : SoftmaxFusionType, default = SoftmaxFusionType.SCALED
        Indicate the type of fused softmax kernel to use.
    softmax_type : AttnSoftmaxType, default = AttnSoftmaxType.VANILLA_SOFTMAX
        Indicate the type of softmax (vanilla, off-by-one, or learnable).
    """

    scale_factor: float = 1.0
    softmax_fusion_type: SoftmaxFusionType = SoftmaxFusionType.SCALED
    softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX
    @nn.compact
    def __call__(
        self, inputs: Array, mask: Array = None, bias: Array = None, softmax_offset: Array = None
    ) -> jnp.ndarray:
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
...@@ -187,33 +194,52 @@ class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
        input_dtype = inputs.dtype
        logits = inputs
        if softmax_offset is not None:
            assert (
                self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX
            ), "softmax_offset is only supported with LEARNABLE_SOFTMAX"
        if self.softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
            softmax_offset = 0.0
        # use primitives
        if is_softmax_kernel_available(
            self.softmax_fusion_type,
            self.softmax_type,
            batch,
            heads,
            q_seqlen,
            k_seqlen,
            input_dtype,
        ):
            if bias is not None:
                logits = logits + bias.astype(input_dtype)
            mask_ = mask
            if self.softmax_fusion_type is not SoftmaxFusionType.SCALED_MASKED:
                mask_ = None
            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_fusion_type)
        # use default jax based implementation
        else:
            warnings.warn(
                "Using unfused JAX softmax implementation instead of TE fused primitives.",
                UserWarning,
                stacklevel=2,
            )
            if bias is not None:
                logits = logits + bias.astype(input_dtype)
            if self.softmax_fusion_type is SoftmaxFusionType.SCALED:
                outputs = jax_scaled_softmax(logits, self.scale_factor, softmax_offset)
            elif self.softmax_fusion_type is SoftmaxFusionType.SCALED_MASKED:
                outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor, softmax_offset)
            elif self.softmax_fusion_type is SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED:
                outputs = jax_scaled_upper_triang_masked_softmax(
                    logits, self.scale_factor, softmax_offset
                )
            else:
                raise ValueError(
                    f"Unsupported softmax fusion: {self.softmax_fusion_type}. softmax_fusion_type"
                    " must be [SCALED, SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
                )

        assert input_dtype == outputs.dtype
        return outputs
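
For reference, a hedged usage sketch of the updated module (illustrative names and shapes, assuming this file's imports): passing a per-head sink via softmax_offset requires LEARNABLE_SOFTMAX and currently takes the unfused JAX path.

import jax.numpy as jnp

logits = jnp.ones((2, 4, 8, 8))                      # [b, h, s_q, s_kv]
offset = jnp.zeros((1, 4, 1, 1), dtype=jnp.float32)  # per-head learnable sink
sm = Softmax(
    scale_factor=0.125,
    softmax_fusion_type=SoftmaxFusionType.SCALED,
    softmax_type=AttnSoftmaxType.LEARNABLE_SOFTMAX,
)
probs = sm.apply({}, logits, softmax_offset=offset)  # module has no params of its own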
...
...@@ -23,11 +23,17 @@ from jax.ad_checkpoint import checkpoint_name
from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
from ..attention import (
    AttnBiasType,
    AttnMaskType,
    AttnSoftmaxType,
    QKVLayout,
    SequenceDescriptor,
)
from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type
from ..attention import fused_attn
from ..attention import CPStrategy
from ..softmax import SoftmaxFusionType
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
from ..sharding import with_sharding_constraint_by_logical_axes
...@@ -120,6 +126,7 @@ class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = True
    window_size: Optional[Tuple[int, int]] = None
    softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX

    @nn.compact
    def __call__(
...@@ -145,6 +152,22 @@ class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-
        input_dtype = query.dtype

        # Infer number of attention heads from query shape
        # query shape: [..., h, d] where h is num_attention_heads
        num_attention_heads = query.shape[-2]

        # Initialize softmax_offset for learnable softmax
        # Note: OFF_BY_ONE_SOFTMAX is handled internally by the Softmax module
        softmax_offset = None
        if self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
            # For learnable softmax, create a learnable parameter with proper sharding
            # and shape (1, h, 1, 1)
            softmax_offset = self.param(
                "softmax_offset",
                nn.with_logical_partitioning(nn.initializers.zeros, (None, HEAD_AXES, None, None)),
                (1, num_attention_heads, 1, 1),
                jnp.float32,
            )
        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])
        else:
...@@ -213,8 +236,8 @@ class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-
            new_mask = jnp.where(original_mask == 0, swa_mask, original_mask)
            return new_mask

        def convert_to_softmax_fusion_type(attn_mask_type, mask):
            """Convert the attn_mask_type to SoftmaxFusionType"""
            # mask is ignored for no_mask and causal_mask without sliding window
            if attn_mask_type == AttnMaskType.NO_MASK:
                mask = None
...@@ -224,21 +247,23 @@ class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-
            mask = apply_swa_mask(mask)

            # Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
            if mask is not None:
                return SoftmaxFusionType.SCALED_MASKED, mask
            if attn_mask_type is AttnMaskType.CAUSAL_MASK:
                return SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED, mask
            if attn_mask_type is AttnMaskType.NO_MASK:
                return SoftmaxFusionType.SCALED, mask
            raise ValueError(
                f"Unsupported {attn_mask_type=}, supported attn_mask_type="
                "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}"
            )

        softmax_fusion_type, mask = convert_to_softmax_fusion_type(self.attn_mask_type, mask)
        attn_weights = Softmax(
            softmax_fusion_type=softmax_fusion_type,
            softmax_type=self.softmax_type,
            scale_factor=fused_scale_factor,
        )(attn_weights, mask, bias, softmax_offset=softmax_offset).astype(input_dtype)

        if is_gqa:
            attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
...@@ -279,6 +304,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
    context_parallel_axis: str = ""
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT
    context_checkpoint_name: str = "context"
    softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX

    @nn.compact
    def __call__(
...@@ -303,6 +329,17 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
            scale_factor = self.scale_factor
        del self.scale_factor

        num_attention_heads = query.shape[-2]
        softmax_offset = None
        if self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
            # For learnable softmax, create a learnable parameter with proper sharding
            # and shape (1, h, 1, 1)
            softmax_offset = self.param(
                "softmax_offset",
                nn.with_logical_partitioning(nn.initializers.zeros, (None, HEAD_AXES, None, None)),
                (1, num_attention_heads, 1, 1),
                jnp.float32,
            )

        if self.qkv_layout.is_qkvpacked():
            """qkvpacked format, treat
            query: qkvpacked tensor, shape = [..., 3, h, d]
...@@ -320,6 +357,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
                attn_mask_type=self.attn_mask_type,
                attn_bias_type=self.attn_bias_type,
                qkv_layout=self.qkv_layout,
                softmax_type=self.softmax_type,
                scaling_factor=scale_factor,
                dropout_probability=self.attention_dropout,
                is_training=not deterministic,
...@@ -329,6 +367,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
                context_parallel_axis=self.context_parallel_axis,
                context_parallel_strategy=self.context_parallel_strategy,
                context_checkpoint_name=self.context_checkpoint_name,
                softmax_offset=softmax_offset,
            )
        elif self.qkv_layout.is_kvpacked():
            """kvpacked format, treat
...@@ -348,6 +387,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
                attn_mask_type=self.attn_mask_type,
                attn_bias_type=self.attn_bias_type,
                qkv_layout=self.qkv_layout,
                softmax_type=self.softmax_type,
                scaling_factor=scale_factor,
                dropout_probability=self.attention_dropout,
                is_training=not deterministic,
...@@ -357,6 +397,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
                context_parallel_axis=self.context_parallel_axis,
                context_parallel_strategy=self.context_parallel_strategy,
                context_checkpoint_name=self.context_checkpoint_name,
                softmax_offset=softmax_offset,
            )
        elif self.qkv_layout.is_separate():
            if self.transpose_batch_sequence:
...@@ -371,6 +412,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
                attn_mask_type=self.attn_mask_type,
                attn_bias_type=self.attn_bias_type,
                qkv_layout=self.qkv_layout,
                softmax_type=self.softmax_type,
                scaling_factor=scale_factor,
                dropout_probability=self.attention_dropout,
                is_training=not deterministic,
...@@ -380,6 +422,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
                context_parallel_axis=self.context_parallel_axis,
                context_parallel_strategy=self.context_parallel_strategy,
                context_checkpoint_name=self.context_checkpoint_name,
                softmax_offset=softmax_offset,
            )
        else:
            raise ValueError(f"Unsupported {self.qkv_layout=}.")
@@ -514,6 +557,17 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     context_parallel_axis (str): The name of the context parallel axis.
     context_parallel_strategy (CPStrategy): The strategy of context parallel. 0: DEFAULT, 1: ALL_GATHER, 2: RING.
     context_checkpoint_name (str): The name of the context checkpoint in the forward pass of fused attention.
+    softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
+        Softmax type as described in the paper
+        `Efficient Streaming Language Models with Attention Sinks
+        <https://arxiv.org/pdf/2309.17453v3>`_.
+        For a given attention score S = Q*K^T of shape [b, h, s_q, s_kv],
+        'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
+        'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
+        'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
+        where alpha is a learnable parameter of shape [h].
+        The 'off-by-one' and 'learnable' softmax types are also called sink attention
+        ('zero sink' and 'learnable sink', respectively).
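Note: the 'off-by-one' variant is exactly a vanilla softmax computed over the logits with one extra zero "sink" logit appended and then dropped. The stand-alone JAX sketch below (illustrative, not part of this diff) checks the equivalence:

```python
import jax
import jax.numpy as jnp

def off_by_one_softmax(logits):
    # softmax_1(x)_i = exp(x_i) / (1 + sum_j exp(x_j)), computed stably:
    # clamp the running max at 0 so the implicit zero logit is covered.
    m = jnp.maximum(jnp.max(logits, axis=-1, keepdims=True), 0.0)
    e = jnp.exp(logits - m)
    return e / (jnp.exp(-m) + jnp.sum(e, axis=-1, keepdims=True))

def zero_sink_softmax(logits):
    # Equivalent form: append a zero sink logit, vanilla softmax, drop it.
    zeros = jnp.zeros((*logits.shape[:-1], 1), logits.dtype)
    return jax.nn.softmax(jnp.concatenate([logits, zeros], axis=-1))[..., :-1]

x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 8, 8))
assert jnp.allclose(off_by_one_softmax(x), zero_sink_softmax(x), atol=1e-6)
```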
     Optimization parameters
     -----------------------
@@ -539,6 +593,7 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     context_parallel_axis: str = ""
     context_parallel_strategy: str = "DEFAULT"
     context_checkpoint_name: str = "context"
+    softmax_type: str = "vanilla"

     @nn.compact
     def __call__(
@@ -595,6 +650,7 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
         attn_bias_type = AttnBiasType[self.attn_bias_type.upper()]
         attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
         qkv_layout = QKVLayout[self.qkv_layout.upper()]
+        softmax_type = AttnSoftmaxType.from_str(self.softmax_type)
         del self.attn_bias_type, self.attn_mask_type, self.qkv_layout

         if attn_bias_type == AttnBiasType.NO_BIAS:
@@ -626,6 +682,7 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
             qkv_layout,
             attn_bias_type,
             attn_mask_type,
+            softmax_type,
             self.attention_dropout,
             self.num_attention_heads,
             self.num_gqa_groups,
@@ -702,6 +759,7 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
                 scale_factor=scale_factor,
                 transpose_batch_sequence=self.transpose_batch_sequence,
                 window_size=self.window_size,
+                softmax_type=softmax_type,
             )(
                 query,
                 key,
@@ -726,6 +784,7 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
                 context_parallel_axis=self.context_parallel_axis,
                 context_parallel_strategy=context_parallel_strategy,
                 context_checkpoint_name=self.context_checkpoint_name,
+                softmax_type=softmax_type,
             )(
                 query,
                 key,
@@ -1005,6 +1064,17 @@ class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
         Deprecated. Please refer to `fuse_qkv_params`.
     window_size: Optional[Tuple[int, int]], default = None
         Sliding window size. Default value is no sliding window.
+    softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
+        Softmax type as described in the paper
+        `Efficient Streaming Language Models with Attention Sinks
+        <https://arxiv.org/pdf/2309.17453v3>`_.
+        For a given attention score S = Q*K^T of shape [b, h, s_q, s_kv],
+        'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
+        'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
+        'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
+        where alpha is a learnable parameter of shape [h].
+        The 'off-by-one' and 'learnable' softmax types are also called sink attention
+        ('zero sink' and 'learnable sink', respectively).
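Note: a hypothetical construction sketch; the hyperparameter values are illustrative, the import path is assumed, and only `softmax_type` is new in this change (the remaining fields follow the attribute list below):

```python
from transformer_engine.jax.flax import MultiHeadAttention  # import path assumed

mha = MultiHeadAttention(
    head_dim=64,
    num_attention_heads=16,
    softmax_type="learnable",  # expected to add a per-head 'softmax_offset' param
)
```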
""" """
head_dim: int head_dim: int
...@@ -1036,6 +1106,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1036,6 +1106,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
scaled_query_init: bool = True scaled_query_init: bool = True
float32_logits: bool = False float32_logits: bool = False
window_size: Optional[Tuple[int, int]] = None window_size: Optional[Tuple[int, int]] = None
softmax_type: str = "vanilla"
# Deprecated parameters # Deprecated parameters
num_heads: Optional[int] = None num_heads: Optional[int] = None
...@@ -1440,6 +1511,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1440,6 +1511,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
scale_factor=scale_factor, scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
window_size=self.window_size, window_size=self.window_size,
softmax_type=self.softmax_type,
)(*dpa_args, mask, bias, deterministic=deterministic) )(*dpa_args, mask, bias, deterministic=deterministic)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3])) x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
@@ -1721,6 +1793,18 @@ class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
         Whether to enable sequence parallelism to operations except dot.
     window_size: Optional[Tuple[int, int]], default = None
         Sliding window size. Default value is no sliding window.
+    softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
+        Softmax type as described in the paper
+        `Efficient Streaming Language Models with Attention Sinks
+        <https://arxiv.org/pdf/2309.17453v3>`_.
+        For a given attention score S = Q*K^T of shape [b, h, s_q, s_kv],
+        'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
+        'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
+        'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
+        where alpha is a learnable parameter of shape [h].
+        The 'off-by-one' and 'learnable' softmax types are also called sink attention
+        ('zero sink' and 'learnable sink', respectively).
+        Only supported with the fused attention backend.
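Note: a hypothetical end-to-end configuration sketch; argument values and the import path are illustrative, and only `softmax_type` is the new knob here:

```python
from transformer_engine.jax.flax import TransformerLayer  # import path assumed

layer = TransformerLayer(
    hidden_size=1024,
    mlp_hidden_size=4096,
    num_attention_heads=16,
    self_attn_mask_type="causal",
    softmax_type="off-by-one",  # zero-sink attention; adds no new parameters
)
```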
     Optimization parameters
     -----------------------
@@ -1786,6 +1870,7 @@ class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
     scale_attn_logits: bool = False
     scaled_query_init: bool = True
     window_size: Optional[Tuple[int, int]] = None
+    softmax_type: str = "vanilla"

     def __post_init__(self):
         if self.mha_kernel_init is None:
@@ -1946,6 +2031,7 @@ class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
             bias_init=self.bias_init,
             name=mha_name,
             window_size=self.window_size,
+            softmax_type=self.softmax_type,
         )(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode)

         def hidden_dropout(x, deterministic):
@@ -2024,6 +2110,7 @@ class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
             bias_init=self.bias_init,
             name="encoder_decoder_attention",
             window_size=self.window_size,
+            softmax_type=self.softmax_type,
         )(x, encoded, encoder_decoder_mask, deterministic=deterministic)

         y = with_sharding_constraint_by_logical_axes(
@@ -12,8 +12,8 @@ import jax.numpy as jnp
 from . import cpp_extensions as tex

-class SoftmaxType(Enum):
-    """SoftmaxType."""
+class SoftmaxFusionType(Enum):
+    """SoftmaxFusionType."""

     SCALED = "scaled"
     SCALED_MASKED = "scaled_masked"
@@ -24,27 +24,27 @@ def softmax(
     logits: jnp.ndarray,
     mask: Optional[jnp.ndarray] = None,
     scale_factor: Optional[float] = 1.0,
-    softmax_type: Optional[SoftmaxType] = SoftmaxType.SCALED,
+    softmax_fusion_type: Optional[SoftmaxFusionType] = SoftmaxFusionType.SCALED,
 ):
     """
     Softmax wrapper
     """
-    output = _softmax(logits, mask, scale_factor, softmax_type)
+    output = _softmax(logits, mask, scale_factor, softmax_fusion_type)
     return output


 @partial(jax.custom_vjp, nondiff_argnums=(2, 3))
-def _softmax(logits, mask, scale_factor, softmax_type):
-    output, _ = _softmax_fwd_rule(logits, mask, scale_factor, softmax_type)
+def _softmax(logits, mask, scale_factor, softmax_fusion_type):
+    output, _ = _softmax_fwd_rule(logits, mask, scale_factor, softmax_fusion_type)
     return output


-def _softmax_fwd_rule(logits, mask, scale_factor, softmax_type):
-    if softmax_type is SoftmaxType.SCALED_MASKED:
+def _softmax_fwd_rule(logits, mask, scale_factor, softmax_fusion_type):
+    if softmax_fusion_type is SoftmaxFusionType.SCALED_MASKED:
         assert mask is not None
         output = tex.scaled_masked_softmax_fwd(logits, mask, scale_factor)
-    elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
+    elif softmax_fusion_type is SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED:
         output = tex.scaled_upper_triang_masked_softmax_fwd(logits, scale_factor)
     else:
         output = tex.scaled_softmax_fwd(logits, scale_factor)
@@ -52,12 +52,12 @@ def _softmax_fwd_rule(logits, mask, scale_factor, softmax_type):
     return output, (output, logits, mask)


-def _softmax_bwd_rule(scale_factor, softmax_type, ctx, dz):
+def _softmax_bwd_rule(scale_factor, softmax_fusion_type, ctx, dz):
     (softmax_output, logits, mask) = ctx
-    if softmax_type is SoftmaxType.SCALED_MASKED:
+    if softmax_fusion_type is SoftmaxFusionType.SCALED_MASKED:
         dgrad = tex.scaled_masked_softmax_bwd(dz, softmax_output, logits, mask, scale_factor)
-    elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
+    elif softmax_fusion_type is SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED:
         dgrad = tex.scaled_upper_triang_masked_softmax_bwd(dz, softmax_output, logits, scale_factor)
     else:
         dgrad = tex.scaled_softmax_bwd(dz, softmax_output, logits, scale_factor)
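Note: a usage sketch under the renamed API. The import path and shapes are assumptions (the fused kernels also require a supported GPU and dtype), so treat this as illustrative:

```python
import math
import jax
import jax.numpy as jnp
from transformer_engine.jax.softmax import softmax, SoftmaxFusionType  # path assumed

head_dim = 64
# Attention logits of shape [b, h, s_q, s_kv].
logits = jax.random.normal(jax.random.PRNGKey(0), (2, 16, 128, 128), jnp.float16)
probs = softmax(
    logits,
    mask=None,
    scale_factor=1.0 / math.sqrt(head_dim),
    softmax_fusion_type=SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED,
)
```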
@@ -156,7 +156,9 @@ class FusedScaleMaskSoftmax(nn.Module):
         softmax_in_fp32: bool = True,
     ) -> None:
         super().__init__()
-        self.scaled_masked_softmax_fusion = bool(int(os.getenv("NVTE_MASKED_SOFTMAX_FUSION", "1")))
+        self.scaled_masked_softmax_fusion_type = bool(
+            int(os.getenv("NVTE_MASKED_SOFTMAX_FUSION", "1"))
+        )
         self.mask_func = mask_func
         self.softmax_in_fp32 = softmax_in_fp32
@@ -189,7 +191,7 @@ class FusedScaleMaskSoftmax(nn.Module):
         """Check FusedScaleMaskSoftmax kernel availability based on size"""
         attn_batches = b * np
-        if not self.scaled_masked_softmax_fusion:
+        if not self.scaled_masked_softmax_fusion_type:
             return False  # user doesn't want to fuse
         if not self.input_in_float16:
             return False  # input must be fp16
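Note: per the hunk above, the fusion gate is read from the environment once at construction time, so toggling it must happen before the module is built. A minimal sketch:

```python
import os

# Must be set before FusedScaleMaskSoftmax.__init__ runs, since the flag is
# read exactly once there; "0" disables the fused masked-softmax kernel and
# falls back to the unfused path.
os.environ["NVTE_MASKED_SOFTMAX_FUSION"] = "0"
```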