Unverified Commit 15cefbc5 authored by Paweł Gadziński, committed by GitHub

[JAX] Add support for sink attention in JAX (#2225)



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* removed packed versions
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* jax
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* softmax_fusion -> softmax_fusion_type
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 7e593c3b
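Sink attention lets the softmax place probability mass on an implicit extra logit, so the attention weights for the real keys can sum to less than one. A minimal, self-contained JAX sketch of the off-by-one variant this PR wires up; it uses the same zero-logit trick as the reference implementation in the test diff below, and the names and shapes here are illustrative rather than library API:

import jax
import jax.numpy as jnp

def off_by_one_softmax(logits: jnp.ndarray) -> jnp.ndarray:
    # exp(x_i) / (1 + sum_j exp(x_j)) along the last axis:
    # append a zero logit, run a standard softmax, drop the extra column.
    zero = jnp.zeros(logits.shape[:-1] + (1,), dtype=logits.dtype)
    probs = jax.nn.softmax(jnp.concatenate([logits, zero], axis=-1), axis=-1)
    return probs[..., :-1]

x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 8, 8))
p = off_by_one_softmax(x)
print(p.sum(axis=-1).max())  # strictly below 1.0; the missing mass is the sink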
......@@ -18,6 +18,7 @@ from transformer_engine.jax.attention import (
is_fused_attn_kernel_available,
AttnBiasType,
AttnMaskType,
AttnSoftmaxType,
QKVLayout,
QKVFormat,
reorder_causal_load_balancing,
......@@ -66,6 +67,7 @@ class TestDistributedSelfAttn:
bias_shape,
attn_mask_type,
dtype,
softmax_type,
use_shardy,
):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
......@@ -80,6 +82,7 @@ class TestDistributedSelfAttn:
QKVLayout.BS3HD,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
num_head,
num_head,
......@@ -109,6 +112,7 @@ class TestDistributedSelfAttn:
hidden,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
......@@ -142,6 +146,14 @@ class TestDistributedSelfAttn:
],
)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize(
"softmax_type",
[
pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
],
)
def test_self_attn(
self,
device_count,
......@@ -153,6 +165,7 @@ class TestDistributedSelfAttn:
bias_shape,
attn_mask_type,
dtype,
softmax_type,
):
self.impl_test_self_attn(
device_count,
......@@ -164,6 +177,7 @@ class TestDistributedSelfAttn:
bias_shape,
attn_mask_type,
dtype,
softmax_type,
use_shardy=False,
)
......@@ -175,8 +189,23 @@ class TestDistributedSelfAttn:
pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
],
)
@pytest.mark.parametrize(
"softmax_type",
[
pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
],
)
def test_self_attn_shardy(
self, device_count, mesh_shape, mesh_axes, mesh_resource, attn_bias_type, bias_shape
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
attn_bias_type,
bias_shape,
softmax_type,
):
data_shape = (32, 512, 12, 64)
self.impl_test_self_attn(
......@@ -189,6 +218,7 @@ class TestDistributedSelfAttn:
bias_shape,
AttnMaskType.PADDING_MASK,
jnp.bfloat16,
softmax_type,
use_shardy=True,
)
......@@ -213,8 +243,24 @@ class TestDistributedCrossAttn:
"attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize(
"softmax_type",
[
pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
],
)
def test_cross_attn(
self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
attn_mask_type,
dtype,
softmax_type,
):
attn_bias_type = AttnBiasType.NO_BIAS
bias_shape = None
......@@ -230,6 +276,7 @@ class TestDistributedCrossAttn:
QKVLayout.BSHD_BS2HD,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
num_head,
num_head,
......@@ -252,6 +299,7 @@ class TestDistributedCrossAttn:
hidden,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
......@@ -322,6 +370,8 @@ class TestDistributedContextParallelSelfAttn:
bias_shape = None
dropout_prob = 0.0
is_training = True
# Context parallel does not support softmax_offset
softmax_type = AttnSoftmaxType.VANILLA_SOFTMAX
dp_size, cp_size, tp_size = mesh_shape
batch, seqlen, num_head, hidden = data_shape
......@@ -343,6 +393,7 @@ class TestDistributedContextParallelSelfAttn:
hidden,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
......@@ -366,6 +417,7 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout,
attn_bias_type,
mask_type,
softmax_type,
dropout_prob,
num_head,
num_kv_heads,
......
......@@ -16,7 +16,7 @@ from distributed_test_base import generate_configs, generate_collectives_count
from distributed_test_base import compare_ops
from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import autocast
from transformer_engine.jax.softmax import SoftmaxType, softmax
from transformer_engine.jax.softmax import SoftmaxFusionType, softmax
DTYPES = [jnp.float16, jnp.bfloat16]
......@@ -29,12 +29,12 @@ class TestDistributedSoftmax:
return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)
def generate_inputs(
self, shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask
self, shape, mesh_resource, softmax_fusion_type, dtype, bad_sharding, broadcast_batch_mask
):
batch, _, sqelen, _ = shape
x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
if softmax_type == SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
if softmax_fusion_type == SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED:
mask = make_causal_mask(batch, sqelen)
else:
mask = make_self_mask(1 if broadcast_batch_mask else batch, sqelen)
......@@ -56,8 +56,10 @@ class TestDistributedSoftmax:
return (x, mask), (x_pspec, mask_pspec)
@staticmethod
def target_func(x, mask, scale_factor=1.0, softmax_type=SoftmaxType.SCALED):
return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_type=softmax_type))
def target_func(x, mask, scale_factor=1.0, softmax_fusion_type=SoftmaxFusionType.SCALED):
return jnp.mean(
softmax(x, mask, scale_factor=scale_factor, softmax_fusion_type=softmax_fusion_type)
)
@staticmethod
def ref_func(x, mask, scale_factor=1.0, dtype=jnp.float16):
......@@ -80,24 +82,29 @@ class TestDistributedSoftmax:
mesh_axes,
mesh_resource,
data_shape,
softmax_type,
softmax_fusion_type,
scale_factor,
dtype,
bad_sharding,
broadcast_batch_mask,
use_shardy,
):
if broadcast_batch_mask and softmax_type != SoftmaxType.SCALED_MASKED:
if broadcast_batch_mask and softmax_fusion_type != SoftmaxFusionType.SCALED_MASKED:
pytest.skip("Softmax type has no mask.")
jax.config.update("jax_use_shardy_partitioner", use_shardy)
target_func = partial(
self.target_func, scale_factor=scale_factor, softmax_type=softmax_type
self.target_func, scale_factor=scale_factor, softmax_fusion_type=softmax_fusion_type
)
ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype)
(x, mask), (x_pspec, mask_pspec) = self.generate_inputs(
data_shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask
data_shape,
mesh_resource,
softmax_fusion_type,
dtype,
bad_sharding,
broadcast_batch_mask,
)
collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
......@@ -139,8 +146,12 @@ class TestDistributedSoftmax:
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [8, 8, 1024, 1024]])
@pytest.mark.parametrize(
"softmax_type",
[SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],
"softmax_fusion_type",
[
SoftmaxFusionType.SCALED,
SoftmaxFusionType.SCALED_MASKED,
SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED,
],
)
@pytest.mark.parametrize("scale_factor", [1.0, 3.0])
@pytest.mark.parametrize("dtype", DTYPES)
......@@ -153,7 +164,7 @@ class TestDistributedSoftmax:
mesh_axes,
mesh_resource,
data_shape,
softmax_type,
softmax_fusion_type,
scale_factor,
dtype,
bad_sharding,
......@@ -165,7 +176,7 @@ class TestDistributedSoftmax:
mesh_axes,
mesh_resource,
data_shape,
softmax_type,
softmax_fusion_type,
scale_factor,
dtype,
bad_sharding,
......@@ -174,7 +185,9 @@ class TestDistributedSoftmax:
)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("softmax_type", [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED])
@pytest.mark.parametrize(
"softmax_fusion_type", [SoftmaxFusionType.SCALED, SoftmaxFusionType.SCALED_MASKED]
)
@pytest.mark.parametrize("bad_sharding", [False, True])
@pytest.mark.parametrize("broadcast_batch_mask", [False, True])
def test_softmax_gspmd(
......@@ -183,7 +196,7 @@ class TestDistributedSoftmax:
mesh_shape,
mesh_axes,
mesh_resource,
softmax_type,
softmax_fusion_type,
bad_sharding,
broadcast_batch_mask,
):
......@@ -193,7 +206,7 @@ class TestDistributedSoftmax:
mesh_axes,
mesh_resource,
data_shape=[32, 12, 128, 128],
softmax_type=softmax_type,
softmax_fusion_type=softmax_fusion_type,
scale_factor=1.0,
dtype=DTYPES[0],
bad_sharding=bad_sharding,
......
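The change in this file is mechanical: SoftmaxType becomes SoftmaxFusionType and the softmax_type= keyword becomes softmax_fusion_type= at every softmax() call site, freeing the old name for the new attention-softmax enum. A hedged before/after sketch (input shape illustrative):

import jax
import jax.numpy as jnp
from transformer_engine.jax.softmax import SoftmaxFusionType, softmax

x = jax.random.normal(jax.random.PRNGKey(0), (8, 12, 128, 128), dtype=jnp.bfloat16)
# before this PR: softmax(x, None, scale_factor=1.0, softmax_type=SoftmaxType.SCALED)
out = softmax(x, None, scale_factor=1.0, softmax_fusion_type=SoftmaxFusionType.SCALED)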
......@@ -27,6 +27,7 @@ from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.attention import (
AttnBiasType,
AttnMaskType,
AttnSoftmaxType,
QKVLayout,
QKVFormat,
reorder_causal_load_balancing,
......@@ -59,14 +60,16 @@ def init():
yield
@partial(jax.jit, static_argnums=(5, 6, 7, 9))
@partial(jax.jit, static_argnums=(6, 7, 8, 9, 11))
def general_dot_product_attention(
query: ArrayLike,
key: ArrayLike,
value: ArrayLike,
softmax_offset: Optional[ArrayLike],
bias: ArrayLike,
mask: ArrayLike,
deterministic: bool,
softmax_type: AttnSoftmaxType,
scale_factor: float,
dropout_rate: float,
dropout_rng: ArrayLike,
......@@ -99,7 +102,25 @@ def general_dot_product_attention(
mask = jnp.expand_dims(mask, axis=-3)
logits = jnp.where(mask, jnp.finfo(dtype).min, logits)
match softmax_type:
case AttnSoftmaxType.VANILLA_SOFTMAX:
softmax_out = jax.nn.softmax(logits).astype(dtype)
case AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
# Softmax with +1 in denominator: exp(x_i) / (sum(exp(x_j)) + 1)
# Append a zero logit, apply standard softmax, then remove last column
zero_logit = jnp.zeros(logits.shape[:-1] + (1,), dtype=logits.dtype)
logits_with_extra = jnp.concatenate([logits, zero_logit], axis=-1)
softmax_with_extra = jax.nn.softmax(logits_with_extra, axis=-1)
softmax_out = softmax_with_extra[..., :-1].astype(dtype)
case AttnSoftmaxType.LEARNABLE_SOFTMAX:
# Append learnable offset logit, apply standard softmax, then remove last column
learnable_logit = softmax_offset.reshape(1, h_kv, num_groups, 1, 1)
learnable_logit = jnp.broadcast_to(learnable_logit, logits.shape[:-1] + (1,))
logits_with_extra = jnp.concatenate([logits, learnable_logit], axis=-1)
softmax_with_extra = jax.nn.softmax(logits_with_extra, axis=-1)
softmax_out = softmax_with_extra[..., :-1].astype(dtype)
case _:
raise NotImplementedError(f"Unknown {softmax_type=}")
if not deterministic and dropout_rate > 0.0:
keep_prob = 1.0 - dropout_rate
......@@ -238,7 +259,7 @@ def _split_valid_and_invalid(primitive, reference, pad):
return primitive_valid, primitive_invalid, reference_valid, reference_invalid
def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
def jax_dpa(query, key, value, bias, softmax_offset, mask, dropout_rng, **kwargs):
"""
JAX native dot product attention implementation
"""
......@@ -246,11 +267,13 @@ def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
query,
key,
value,
softmax_offset,
bias,
mask,
deterministic=not kwargs["is_training"],
scale_factor=kwargs["scaling_factor"],
dropout_rate=kwargs["dropout_probability"],
softmax_type=kwargs["softmax_type"],
dropout_rng=dropout_rng,
dtype=jnp.float32,
)
......@@ -262,6 +285,7 @@ def customcall_fused_dpa(
key,
value,
bias,
softmax_offset,
sequence_descriptor,
dropout_rng,
**kwargs,
......@@ -283,9 +307,9 @@ def customcall_fused_dpa(
qkv_args = (query, key, value)
case _:
raise ValueError(f"Unsupported {qkv_layout=}")
return fused_attn(qkv_args, bias, sequence_descriptor, dropout_rng, **kwargs).astype(
query.dtype
)
return fused_attn(
qkv_args, bias, sequence_descriptor, dropout_rng, softmax_offset=softmax_offset, **kwargs
).astype(query.dtype)
class BiasShape(Enum):
......@@ -320,6 +344,7 @@ class FusedAttnRunner:
head_dim_v: int
attn_bias_type: AttnBiasType
attn_mask_type: AttnMaskType
softmax_type: AttnSoftmaxType
dropout_prob: float
dtype: DTypeLike
is_training: bool
......@@ -402,6 +427,7 @@ class FusedAttnRunner:
self.qkv_layout,
self.attn_bias_type,
self.attn_mask_type,
self.softmax_type,
self.dropout_prob,
self.num_heads_q,
self.num_heads_kv,
......@@ -439,7 +465,7 @@ class FusedAttnRunner:
self.tp_size = self.mesh.shape.get(self.mesh_resource.tpsp_resource, 1)
key = jax.random.PRNGKey(0)
q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
q_key, k_key, v_key, bias_key, dropout_key, softmax_key = jax.random.split(key, 6)
q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim_qk)
k_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_qk)
......@@ -490,6 +516,13 @@ class FusedAttnRunner:
else:
pad_ratio = 0.0
if self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
self.softmax_offset = jax.random.uniform(
softmax_key, (1, self.num_heads_q, 1, 1), jnp.float32, -1.0
)
else:
self.softmax_offset = None
def gen_valid(bs, max_seqlen, pad_ratio):
pad_len = int(max_seqlen * pad_ratio)
valid_len = max_seqlen - pad_len
......@@ -713,6 +746,16 @@ class FusedAttnRunner:
self.bias_pspec = PartitionSpec()
self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec)
# Softmax offset sharding (1, num_heads, 1, 1)
# Use the same logic as HEAD_AXES: tpsp_resource if enabled, else tp_resource
head_resource = (
self.mesh_resource.tpsp_resource
if self.mesh_resource.tpsp_resource is not None
else self.mesh_resource.tp_resource
)
self.softmax_offset_pspec = PartitionSpec(None, head_resource, None, None)
self.softmax_offset_sharding = NamedSharding(self.mesh, self.softmax_offset_pspec)
self.dropout_rng_pspec = PartitionSpec(
None,
)
......@@ -732,7 +775,7 @@ class FusedAttnRunner:
"""
self._setup_inputs()
args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
args = [self.q, self.k, self.v, self.bias, self.softmax_offset, self.mask, self.dropout_rng]
customcall_args = [
# Put test data onto each GPU for distributed.
......@@ -742,12 +785,14 @@ class FusedAttnRunner:
jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
jax.device_put(self.bias, self.bias_sharding),
jax.device_put(self.softmax_offset, self.softmax_offset_sharding),
jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
]
kwargs = {
"attn_bias_type": self.attn_bias_type,
"attn_mask_type": self.attn_mask_type,
"softmax_type": self.softmax_type,
"scaling_factor": self.scaling_factor,
"dropout_probability": self.dropout_prob,
"is_training": self.is_training,
......@@ -766,6 +811,7 @@ class FusedAttnRunner:
self.qkvo_sharding,
self.qkvo_sharding,
self.bias_sharding,
self.softmax_offset_sharding,
self.seq_desc_sharding,
self.dropout_rng_sharding,
],
......@@ -826,7 +872,7 @@ class FusedAttnRunner:
jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
).astype(self.dtype)
args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
args = [self.q, self.k, self.v, self.bias, self.softmax_offset, self.mask, self.dropout_rng]
customcall_args = [
# TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
# THD params once we support those features on CP.
......@@ -834,12 +880,14 @@ class FusedAttnRunner:
jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
jax.device_put(self.bias, self.bias_sharding),
jax.device_put(self.softmax_offset, self.softmax_offset_sharding),
jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
]
kwargs = {
"attn_bias_type": self.attn_bias_type,
"attn_mask_type": self.attn_mask_type,
"softmax_type": self.softmax_type,
"scaling_factor": self.scaling_factor,
"dropout_probability": self.dropout_prob,
"is_training": self.is_training,
......@@ -866,8 +914,16 @@ class FusedAttnRunner:
# Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
jitted_primitive = jit(
value_and_grad(
lambda q, k, v, bias, *args: grad_func(
customcall_fused_dpa, q, k, v, bias, *args, cp_reverse_out=True, **kwargs
lambda q, k, v, bias, softmax_offset, *args: grad_func(
customcall_fused_dpa,
q,
k,
v,
bias,
softmax_offset,
*args,
cp_reverse_out=True,
**kwargs,
),
arg_nums,
),
......@@ -876,6 +932,7 @@ class FusedAttnRunner:
self.qkvo_sharding,
self.qkvo_sharding,
self.bias_sharding,
self.softmax_offset_sharding,
self.seq_desc_sharding,
self.dropout_rng_sharding,
),
......@@ -883,7 +940,9 @@ class FusedAttnRunner:
)
jitted_reference = jit(
value_and_grad(
lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
lambda q, k, v, bias, softmax_offset, *args: grad_func(
jax_dpa, q, k, v, bias, softmax_offset, *args, **kwargs
),
arg_nums,
)
)
......@@ -976,6 +1035,14 @@ class FusedAttnRunner:
),
],
)
@pytest.mark.parametrize(
"softmax_type",
[
pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
],
)
@pytest.mark.parametrize(
"qkv_layout",
[
......@@ -1084,6 +1151,7 @@ class TestFusedAttn:
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
......@@ -1110,6 +1178,7 @@ class TestFusedAttn:
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
......@@ -1138,6 +1207,7 @@ class TestFusedAttn:
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
qkv_layout,
......@@ -1161,6 +1231,7 @@ class TestFusedAttn:
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
True,
......
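The backward tests above treat softmax_offset as an additional differentiable input. A pure-JAX sketch (not the fused kernel) of why the learnable sink logit receives gradients, following the LEARNABLE_SOFTMAX reference semantics exp(x_i) / (exp(alpha_h) + sum_j exp(x_j)); shapes illustrative:

import jax
import jax.numpy as jnp

def loss(offset, logits):
    # offset: (1, h, 1, 1) learnable sink logit; logits: (b, h, sq, skv)
    extra = jnp.broadcast_to(offset.astype(logits.dtype), logits.shape[:-1] + (1,))
    probs = jax.nn.softmax(jnp.concatenate([logits, extra], axis=-1))[..., :-1]
    return probs.mean()

logits = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 8, 8))
offset = jnp.zeros((1, 4, 1, 1))
grad = jax.grad(loss)(offset, logits)  # same shape as offset, generally nonzero
print(grad.shape)  # (1, 4, 1, 1)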
......@@ -83,6 +83,7 @@ _KEY_OF_FLOAT32_ATTENTION_LOGITS = "float32_attention_logits"
_KEY_OF_USE_BIAS = "use_bias"
_KEY_OF_RELATIVE_EMBEDDING = "enable_relative_embedding"
_KEY_OF_WINDOW_SIZE = "window_size"
_KEY_OF_SOFTMAX_TYPE = "softmax_type"
BASE_ATTRS = {
_KEY_OF_TRANSPOSE_BS: True,
......@@ -276,6 +277,14 @@ ATTRS = [
_KEY_OF_RELATIVE_EMBEDDING: True,
_KEY_OF_SELF_ATTN_BIAS_TYPE: "post_scale_bias",
},
# attrs31
{
_KEY_OF_SOFTMAX_TYPE: "off_by_one",
},
# attrs32
{
_KEY_OF_SOFTMAX_TYPE: "learnable",
},
]
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
......@@ -418,6 +427,9 @@ class EncoderRunner(BaseRunner):
"attention/qkv/ln_bias": "pre_attention_layer_norm/ln_bias",
"attention/query/scale": "pre_attention_layer_norm/scale",
"attention/query/ln_bias": "pre_attention_layer_norm/ln_bias",
"attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset": (
"attention/DotProductAttention_0/softmax_offset"
),
"mlp/wi_kernel": "mlp/wi/kernel",
"mlp/wi_bias": "mlp/wi/bias",
"mlp/wo_kernel": "mlp/wo/kernel",
......@@ -463,10 +475,16 @@ class DecoderRunner(BaseRunner):
"encoder_decoder_attention/qkv/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
"encoder_decoder_attention/query/scale": "pre_cross_attention_layer_norm/scale",
"encoder_decoder_attention/query/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
"encoder_decoder_attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset": (
"encoder_decoder_attention/DotProductAttention_0/softmax_offset"
),
"self_attention/qkv/scale": "pre_self_attention_layer_norm/scale",
"self_attention/qkv/ln_bias": "pre_self_attention_layer_norm/ln_bias",
"self_attention/query/scale": "pre_self_attention_layer_norm/scale",
"self_attention/query/ln_bias": "pre_self_attention_layer_norm/ln_bias",
"self_attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset": (
"self_attention/DotProductAttention_0/softmax_offset"
),
"mlp/wi_kernel": "mlp/wi/kernel",
"mlp/wi_bias": "mlp/wi/bias",
"mlp/wo_kernel": "mlp/wo/kernel",
......
......@@ -17,7 +17,8 @@ from jax.typing import DTypeLike
from utils import assert_allclose
from transformer_engine.jax.cpp_extensions import is_softmax_kernel_available
from transformer_engine.jax.softmax import SoftmaxType, softmax
from transformer_engine.jax.cpp_extensions.attention import AttnSoftmaxType
from transformer_engine.jax.softmax import SoftmaxFusionType, softmax
from transformer_engine.jax.flax.module import Softmax
......@@ -50,8 +51,9 @@ class SoftmaxRunner:
max_seqlen_kv: int
num_heads: int
scale_factor: float
softmax_type: SoftmaxType
softmax_fusion_type: SoftmaxFusionType
dtype: DTypeLike
softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX
@staticmethod
def reference_softmax(logits, mask, scale_factor, **_):
......@@ -68,6 +70,7 @@ class SoftmaxRunner:
def _is_support(self):
return is_softmax_kernel_available(
self.softmax_fusion_type,
self.softmax_type,
self.batch_size,
self.num_heads,
......@@ -85,22 +88,22 @@ class SoftmaxRunner:
self.logits = jax.random.uniform(logits_key, logits_shape, self.dtype, -1.0)
match self.softmax_type:
case SoftmaxType.SCALED:
match self.softmax_fusion_type:
case SoftmaxFusionType.SCALED:
self.mask = None
case SoftmaxType.SCALED_MASKED:
case SoftmaxFusionType.SCALED_MASKED:
self.mask = jax.random.bernoulli(mask_key, shape=mask_shape).astype(jnp.uint8)
case SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
case SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED:
self.mask = (1.0 - jnp.tril(jnp.ones_like(self.logits))).astype(jnp.uint8)
case _:
raise ValueError(f"Unknown {self.softmax_type=}")
raise ValueError(f"Unknown {self.softmax_fusion_type=}")
def test_forward(self):
"""
Test transformer_engine.jax.softmax.softmax fwd rule
"""
self._setup_inputs()
primitive_out = softmax(self.logits, self.mask, self.scale_factor, self.softmax_type)
primitive_out = softmax(self.logits, self.mask, self.scale_factor, self.softmax_fusion_type)
reference_out = __class__.reference_softmax(self.logits, self.mask, self.scale_factor)
assert_allclose(primitive_out, reference_out, dtype=self.dtype)
......@@ -117,7 +120,7 @@ class SoftmaxRunner:
args = [self.logits, self.mask]
kwargs = {
"scale_factor": self.scale_factor,
"softmax_type": self.softmax_type,
"softmax_fusion_type": self.softmax_fusion_type,
}
# Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
......@@ -175,7 +178,7 @@ class SoftmaxModuleRunner:
rng = jax.random.PRNGKey(0)
softmax_module = Softmax(
scale_factor=runner.scale_factor,
softmax_type=runner.softmax_type,
softmax_fusion_type=runner.softmax_fusion_type,
)
softmax_vars = softmax_module.init(rng, runner.logits, runner.mask)
module_out = softmax_module.apply(softmax_vars, runner.logits, runner.mask)
......@@ -194,11 +197,11 @@ class SoftmaxModuleRunner:
)
@pytest.mark.parametrize("scale_factor", [0.125])
@pytest.mark.parametrize(
"softmax_type",
"softmax_fusion_type",
[
pytest.param(SoftmaxType.SCALED, id="SCALED"),
pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"),
pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
pytest.param(SoftmaxFusionType.SCALED, id="SCALED"),
pytest.param(SoftmaxFusionType.SCALED_MASKED, id="SCALED_MASKED"),
pytest.param(SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
],
)
@pytest.mark.parametrize(
......@@ -214,19 +217,19 @@ class TestSoftmaxPrimitives:
"""
@staticmethod
def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
def test_forward(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype):
"""
Test forward with parameterized configs
"""
runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype)
runner.test_forward()
@staticmethod
def test_backward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
def test_backward(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype):
"""
Test backward with parameterized configs
"""
runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype)
runner.test_backward()
......@@ -243,11 +246,11 @@ class TestSoftmaxPrimitives:
)
@pytest.mark.parametrize("scale_factor", [0.125])
@pytest.mark.parametrize(
"softmax_type",
"softmax_fusion_type",
[
pytest.param(SoftmaxType.SCALED, id="SCALED"),
pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"),
pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
pytest.param(SoftmaxFusionType.SCALED, id="SCALED"),
pytest.param(SoftmaxFusionType.SCALED_MASKED, id="SCALED_MASKED"),
pytest.param(SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
],
)
@pytest.mark.parametrize(
......@@ -263,11 +266,11 @@ class TestSoftmaxModule:
"""
@staticmethod
def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
def test_forward(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype):
"""
Test forward with parameterized configs
"""
module_runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
module_runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype)
bias = None
runner = SoftmaxModuleRunner(module_runner, bias)
runner.test_forward()
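The runner above defaults softmax_type to VANILLA_SOFTMAX because the fused softmax kernels only cover vanilla softmax; the availability check now takes both enums and rejects the sink variants up front (see the cpp_extensions change further down). A hedged sketch:

import jax.numpy as jnp
from transformer_engine.jax.cpp_extensions import is_softmax_kernel_available
from transformer_engine.jax.cpp_extensions.attention import AttnSoftmaxType
from transformer_engine.jax.softmax import SoftmaxFusionType

ok = is_softmax_kernel_available(
    SoftmaxFusionType.SCALED,
    AttnSoftmaxType.OFF_BY_ONE_SOFTMAX,
    32, 12, 128, 128,  # batch, heads, q_seqlen, k_seqlen
    jnp.float16,
)
assert not ok  # non-vanilla softmax always falls back to the unfused JAX path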
......@@ -21,6 +21,7 @@ from jax import random as jax_random
import pytest
from transformer_engine.jax.attention import (
AttnSoftmaxType,
canonicalize_attn_mask_type,
make_swa_mask,
)
......@@ -162,6 +163,7 @@ class DotProductAttention(nn.Module):
dropout_rate: float = 0.0
dtype: DType = jnp.float32
float32_logits: bool = False
softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX
"""Computes dot-product attention given query, key, and value.
This is the core function for applying attention based on
......@@ -211,6 +213,24 @@ class DotProductAttention(nn.Module):
assert key.shape[-2] == value.shape[-2], "k, v num_heads must match."
assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."
# Infer number of attention heads from query shape
# query shape: [..., h, d] where h is num_attention_heads
num_attention_heads = query.shape[-2]
# Initialize softmax_offset for off-by-one or learnable softmax
softmax_offset = None
if self.softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
# For off-by-one softmax, use zeros with shape (1, h, 1, 1)
softmax_offset = jnp.zeros((1, num_attention_heads, 1, 1), dtype=input_dtype)
elif self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
# For learnable softmax, create a learnable parameter with shape (1, h, 1, 1)
softmax_offset = self.param(
"softmax_offset",
nn.initializers.zeros,
(1, num_attention_heads, 1, 1),
jnp.float32,
)
if self.scale_attn_logits:
head_dim = query.shape[-1]
depth_scaling = jnp.sqrt(head_dim).astype(input_dtype)
......@@ -241,9 +261,23 @@ class DotProductAttention(nn.Module):
if bias is not None:
attn_weights = attn_weights + bias.astype(attn_weights.dtype)
# Add attention sink to the last column if not vanilla softmax
if self.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
# Add extra column with softmax_offset
# softmax_offset shape: (1, h, 1, 1), attn_weights shape: [b, h, q, k]
extra_col = jnp.broadcast_to(
softmax_offset,
(attn_weights.shape[0], attn_weights.shape[1], attn_weights.shape[2], 1),
)
attn_weights = jnp.concatenate([attn_weights, extra_col], axis=-1)
# Normalize the attention weights across `kv_length` dimension.
attn_weights = jax_nn.softmax(attn_weights).astype(input_dtype)
# Remove the extra column after softmax if not vanilla softmax
if self.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
attn_weights = attn_weights[..., :-1]
# Apply attention dropout.
if not deterministic and self.dropout_rate > 0.0:
keep_prob = 1.0 - self.dropout_rate
......@@ -535,6 +569,7 @@ class MultiHeadAttention(nn.Module):
rotary_pos_emb_group_method: str = "consecutive"
fuse_qkv: bool = True
use_bias: bool = False
softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX
def __post_init__(self):
if self.kernel_init is None:
......@@ -801,6 +836,7 @@ class MultiHeadAttention(nn.Module):
dropout_rate=self.dropout_rate,
dtype=self.dtype,
float32_logits=self.float32_logits,
softmax_type=self.softmax_type,
)(query, key, value, bias=attention_bias, deterministic=deterministic)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
......@@ -1058,6 +1094,7 @@ class EncoderLayer(nn.Module):
self_attn_bias_type: Any = None
self_attn_mask_type: str = "no_mask"
window_size: Tuple[int, int] = (-1, -1)
softmax_type: str = "vanilla"
def __post_init__(self):
if self.num_gqa_groups is None:
......@@ -1111,6 +1148,9 @@ class EncoderLayer(nn.Module):
else:
x = inputs
# Convert softmax_type string to AttnSoftmaxType enum
attn_softmax_type = AttnSoftmaxType.from_str(self.softmax_type)
# [batch, length, emb_dim] -> [batch, length, emb_dim]
x = MultiHeadAttention(
num_heads=self.num_attention_heads,
......@@ -1126,6 +1166,7 @@ class EncoderLayer(nn.Module):
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
use_bias=self.use_bias,
softmax_type=attn_softmax_type,
name="attention",
)(x, x, encoder_mask, encoder_bias, deterministic=deterministic)
x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
......@@ -1222,6 +1263,7 @@ class DecoderLayer(nn.Module):
self_attn_bias_type: Any = None
self_attn_mask_type: str = "no_mask"
window_size: Tuple[int, int] = (-1, -1)
softmax_type: str = "vanilla"
def __post_init__(self):
if self.num_gqa_groups is None:
......@@ -1290,6 +1332,9 @@ class DecoderLayer(nn.Module):
else:
x = inputs
# Convert softmax_type string to AttnSoftmaxType enum
attn_softmax_type = AttnSoftmaxType.from_str(self.softmax_type)
# Self-attention block
x = MultiHeadAttention(
num_heads=self.num_attention_heads,
......@@ -1305,6 +1350,7 @@ class DecoderLayer(nn.Module):
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
fuse_qkv=self.fuse_qkv_params,
use_bias=self.use_bias,
softmax_type=attn_softmax_type,
name="self_attention",
)(x, x, decoder_mask, decoder_bias, deterministic=deterministic, decode=decode)
x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
......@@ -1343,6 +1389,7 @@ class DecoderLayer(nn.Module):
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
fuse_qkv=self.fuse_qkv_params,
use_bias=self.use_bias,
softmax_type=attn_softmax_type,
name="encoder_decoder_attention",
)(y, encoded, encoder_decoder_mask, deterministic=deterministic)
y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
......
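The string-to-enum plumbing used by the EncoderLayer and DecoderLayer test utilities above is a thin wrapper over AttnSoftmaxType.from_str (defined in the attention module below). A hedged sketch of the accepted values:

from transformer_engine.jax.attention import AttnSoftmaxType

assert AttnSoftmaxType.from_str("off_by_one") is AttnSoftmaxType.OFF_BY_ONE_SOFTMAX
assert AttnSoftmaxType.from_str("learnable") is AttnSoftmaxType.LEARNABLE_SOFTMAX
# EncoderLayer(softmax_type="learnable") threads the enum through
# MultiHeadAttention into DotProductAttention, which then creates the
# "softmax_offset" parameter.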
......@@ -18,6 +18,7 @@ from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_QKV_Format
from transformer_engine_jax import nvte_get_qkv_format
from transformer_engine_jax import NVTE_Softmax_Type
from . import cpp_extensions as tex
......@@ -74,6 +75,35 @@ class AttnMaskType(Enum):
]
class AttnSoftmaxType(Enum):
"""
VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)),
LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter of shape [H].
"""
VANILLA_SOFTMAX = NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX
OFF_BY_ONE_SOFTMAX = NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX
LEARNABLE_SOFTMAX = NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX
@classmethod
def from_str(cls, softmax_type: str) -> "AttnSoftmaxType":
"""Convert string to AttnSoftmaxType: 'vanilla', 'off_by_one', or 'learnable'."""
softmax_type_map = {
"vanilla": cls.VANILLA_SOFTMAX,
"off_by_one": cls.OFF_BY_ONE_SOFTMAX,
"learnable": cls.LEARNABLE_SOFTMAX,
}
result = softmax_type_map.get(softmax_type)
if result is None:
raise ValueError(
f"Unknown softmax_type: {softmax_type}. "
"Valid options: 'vanilla', 'off_by_one', 'learnable'"
)
return result
class QKVFormat(Enum):
"""
SBHD: q,k,v memory layout with [s, b, ..., h, d]
......@@ -301,6 +331,7 @@ def is_fused_attn_kernel_available(
qkv_layout,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_probability,
q_num_heads,
kv_num_heads,
......@@ -313,6 +344,7 @@ def is_fused_attn_kernel_available(
"""
Check whether the fused attention kernel is supported.
"""
window_size_tuple = (-1, -1) if window_size is None else window_size
def make_helper(attn_mask_type):
return tex.FusedAttnHelper(
......@@ -322,6 +354,7 @@ def is_fused_attn_kernel_available(
qkv_layout,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_probability,
q_num_heads,
kv_num_heads,
......@@ -329,7 +362,7 @@ def is_fused_attn_kernel_available(
kv_max_seqlen,
head_dim_qk,
head_dim_v,
(-1, -1) if window_size is None else window_size,
window_size_tuple,
)
return make_helper(attn_mask_type).is_fused_attn_kernel_available()
......@@ -786,6 +819,7 @@ def _legacy_fused_attn(
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
qkv_layout: QKVLayout,
softmax_type: AttnSoftmaxType,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
......@@ -793,6 +827,7 @@ def _legacy_fused_attn(
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
softmax_offset: Optional[jnp.ndarray] = None,
):
"""
Perform non-THD (non-packed) cuDNN fused attention.
......@@ -815,6 +850,7 @@ def _legacy_fused_attn(
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of attention softmax.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
......@@ -863,10 +899,12 @@ def _legacy_fused_attn(
output = _fused_attn(
qkv,
bias,
softmax_offset,
SequenceDescriptor.from_seqlens((q_seq_lens, kv_seq_lens)),
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
softmax_type=softmax_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
......@@ -900,6 +938,7 @@ def fused_attn_thd(
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
softmax_offset: Optional[jnp.ndarray] = None,
):
"""
Deprecated THD fused attention; please use fused_attn with SequenceDescriptor instead
......@@ -937,6 +976,7 @@ def fused_attn_thd(
output = _fused_attn(
qkv,
bias,
softmax_offset,
SequenceDescriptor.from_seqlens_and_offsets(
(q_seq_lens, kv_seq_lens), (q_seq_offsets, kv_seq_offsets)
),
......@@ -945,6 +985,7 @@ def fused_attn_thd(
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
softmax_type=AttnSoftmaxType.VANILLA_SOFTMAX,
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
......@@ -957,15 +998,17 @@ def fused_attn_thd(
return output
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _fused_attn(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
softmax_offset: Optional[jnp.ndarray],
sequence_descriptor: SequenceDescriptor,
seed: Optional[jnp.ndarray],
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
qkv_layout: QKVLayout,
softmax_type: AttnSoftmaxType,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
......@@ -979,11 +1022,13 @@ def _fused_attn(
output, _ = _fused_attn_fwd_rule(
qkv,
bias,
softmax_offset,
sequence_descriptor,
seed,
attn_bias_type,
attn_mask_type,
qkv_layout,
softmax_type,
scaling_factor,
dropout_probability,
is_training,
......@@ -1000,11 +1045,13 @@ def _fused_attn(
def _fused_attn_fwd_rule(
qkv,
bias,
softmax_offset,
sequence_descriptor,
seed,
attn_bias_type,
attn_mask_type,
qkv_layout,
softmax_type,
scaling_factor,
dropout_probability,
is_training,
......@@ -1018,10 +1065,12 @@ def _fused_attn_fwd_rule(
output, softmax_aux, rng_state = tex.fused_attn_fwd(
qkv,
bias,
softmax_offset,
sequence_descriptor,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
softmax_type=softmax_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
......@@ -1041,6 +1090,7 @@ def _fused_attn_fwd_rule(
sequence_descriptor,
softmax_aux,
rng_state,
softmax_offset,
output,
)
......@@ -1049,6 +1099,7 @@ def _fused_attn_bwd_rule(
attn_bias_type,
attn_mask_type,
qkv_layout,
softmax_type,
scaling_factor,
dropout_probability,
is_training,
......@@ -1068,11 +1119,13 @@ def _fused_attn_bwd_rule(
sequence_descriptor,
softmax_aux,
rng_state,
softmax_offset,
output,
) = ctx
grad_qkv, grad_bias = tex.fused_attn_bwd(
grad_qkv, grad_bias, grad_softmax_offset = tex.fused_attn_bwd(
qkv,
bias,
softmax_offset,
softmax_aux,
rng_state,
output,
......@@ -1080,6 +1133,7 @@ def _fused_attn_bwd_rule(
sequence_descriptor,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
softmax_type=softmax_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
......@@ -1092,9 +1146,12 @@ def _fused_attn_bwd_rule(
)
if attn_bias_type == AttnBiasType.NO_BIAS:
grad_bias = None
if softmax_type != AttnSoftmaxType.LEARNABLE_SOFTMAX:
grad_softmax_offset = None
return (
grad_qkv,
grad_bias,
grad_softmax_offset,
None,
None,
)
......@@ -1111,6 +1168,7 @@ def fused_attn(
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
qkv_layout: QKVLayout,
softmax_type: AttnSoftmaxType,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
......@@ -1120,6 +1178,7 @@ def fused_attn(
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
context_checkpoint_name: str = "context",
softmax_offset: Optional[jnp.ndarray] = None,
):
"""
Perform cuDNN fused attention.
......@@ -1139,6 +1198,7 @@ def fused_attn(
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of attention softmax.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
......@@ -1153,6 +1213,9 @@ def fused_attn(
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
context_checkpoint_name (str): The name of the context checkpoint for the custom VJP forward pass.
softmax_offset (Optional[jnp.ndarray]): An optional learnable softmax offset tensor with shape
[1, num_heads, 1, 1]. Used when softmax_type is AttnSoftmaxType.LEARNABLE_SOFTMAX.
If provided, this parameter will receive gradients during backpropagation.
Returns:
(jnp.ndarray): The output tensor from the fused attention.
......@@ -1200,6 +1263,7 @@ def fused_attn(
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
softmax_type=softmax_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
......@@ -1208,15 +1272,18 @@ def fused_attn(
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
softmax_offset=softmax_offset,
)
output = _fused_attn(
qkv,
bias,
softmax_offset,
sequence_descriptor,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
softmax_type=softmax_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
......
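End to end, the public fused_attn now takes softmax_type plus an optional softmax_offset keyword. A hedged usage sketch for packed BS3HD self-attention; the argument values are illustrative, and the qkv tuple packing and SequenceDescriptor.from_seqlens usage follow the call sites shown above:

import jax
import jax.numpy as jnp
from transformer_engine.jax.attention import (
    AttnBiasType, AttnMaskType, AttnSoftmaxType, QKVLayout,
    SequenceDescriptor, fused_attn,
)

b, s, h, d = 2, 128, 8, 64
qkv = jax.random.normal(jax.random.PRNGKey(0), (b, s, 3, h, d), jnp.bfloat16)
seqlens = jnp.full((b,), s, jnp.int32)
seq_desc = SequenceDescriptor.from_seqlens((seqlens, seqlens))
offset = jnp.zeros((1, h, 1, 1), jnp.float32)  # learnable sink logits

out = fused_attn(
    (qkv,),    # packed qkv tuple for BS3HD
    None,      # bias
    seq_desc,
    None,      # seed (no dropout)
    attn_bias_type=AttnBiasType.NO_BIAS,
    attn_mask_type=AttnMaskType.NO_MASK,
    qkv_layout=QKVLayout.BS3HD,
    softmax_type=AttnSoftmaxType.LEARNABLE_SOFTMAX,
    scaling_factor=1.0 / d**0.5,
    dropout_probability=0.0,
    is_training=True,
    softmax_offset=offset,
)

Note that the distributed tests pin softmax_type to VANILLA_SOFTMAX under context parallelism, since softmax_offset is not supported there.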
......@@ -11,10 +11,11 @@ import jax
import jax.numpy as jnp
from jax import dtypes, ffi
from jax.sharding import PartitionSpec, NamedSharding
from .attention import AttnSoftmaxType
from .base import BasePrimitive, register_primitive
from .misc import get_padded_spec, check_valid_batch_dims
from ..softmax import SoftmaxType
from ..softmax import SoftmaxFusionType
__all__ = [
......@@ -32,7 +33,8 @@ __all__ = [
def is_softmax_kernel_available(
softmax_type: SoftmaxType,
softmax_fusion_type: SoftmaxFusionType,
softmax_type: AttnSoftmaxType,
batch: int,
heads: int,
q_seqlen: int,
......@@ -40,15 +42,18 @@ def is_softmax_kernel_available(
dtype: jnp.dtype,
):
"""check softmax available"""
if softmax_type is SoftmaxType.SCALED:
if softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
return False
if softmax_fusion_type is SoftmaxFusionType.SCALED:
return ScaledSoftmaxFwdPrimitive.is_kernel_available(
batch, heads, q_seqlen, k_seqlen, dtype
)
if softmax_type is SoftmaxType.SCALED_MASKED:
if softmax_fusion_type is SoftmaxFusionType.SCALED_MASKED:
return ScaledMaskedSoftmaxFwdPrimitive.is_kernel_available(
batch, heads, q_seqlen, k_seqlen, dtype
)
if softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
if softmax_fusion_type is SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED:
return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
batch, heads, q_seqlen, k_seqlen, dtype
)
......@@ -792,26 +797,77 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)
def jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float):
def jax_scaled_softmax(
logits: jnp.ndarray, scale_factor: float, softmax_offset: jnp.ndarray | float | None = None
):
"""
JAX based implementation of scaled softmax
"""
if softmax_offset is not None:
return jax_general_softmax(scale_factor * logits, offset=softmax_offset)
return jax.nn.softmax(scale_factor * logits)
def jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float):
def jax_scaled_masked_softmax(
logits: jnp.ndarray,
mask: jnp.ndarray,
scale_factor: float,
softmax_offset: jnp.ndarray | float | None = None,
):
"""
JAX based implementation of scaled and masked softmax
"""
if softmax_offset is not None:
return jax_general_softmax(logits * scale_factor, offset=softmax_offset, where=mask != 1)
return jax.nn.softmax(logits * scale_factor, where=mask != 1)
def jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
def jax_scaled_upper_triang_masked_softmax(
logits: jnp.ndarray, scale_factor: float, softmax_offset: jnp.ndarray | float | None = None
):
"""
JAX based implementation of scaled and upper triangle masked softmax
"""
mask = 1 - jnp.tril(jnp.ones_like(logits))
return jax_scaled_masked_softmax(logits, mask, scale_factor)
return jax_scaled_masked_softmax(logits, mask, scale_factor, softmax_offset)
def jax_general_softmax(
x: jnp.ndarray,
axis: int = -1,
where: jnp.ndarray | None = None,
initial: jnp.ndarray = -jnp.inf,
offset: jnp.ndarray | float | None = None,
) -> jnp.ndarray:
"""
JAX based implementation of general softmax with optional masking and offset.
"""
# Compute max of x
x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
if offset is not None:
# Cast offset to x.dtype to prevent type promotion
if isinstance(offset, (int, float)):
offset = jnp.array(offset, dtype=x.dtype)
else:
offset = offset.astype(x.dtype)
# Include offset in max: x_max = max(x_max, offset)
# This is equivalent to computing max over [x..., offset]
x_max = jnp.maximum(x_max, offset)
x_safe = x if where is None else jnp.where(where, x, initial)
unnormalized = jnp.exp(x_safe - x_max)
denominator = jnp.sum(unnormalized, axis, where=where, keepdims=True)
if offset is not None:
# Add exp(offset - x_max) to denominator
denominator = denominator + jnp.exp(offset - x_max)
result = unnormalized / denominator
if where is not None:
result = jnp.where(where, result, 0)
return result
def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
......
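A quick equivalence check for jax_general_softmax above (module path assumed from this file's location): with offset=0.0 it reproduces the concat-a-zero-logit formulation of off-by-one softmax without materializing the extra column:

import jax
import jax.numpy as jnp
from transformer_engine.jax.cpp_extensions.softmax import jax_general_softmax

x = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 8, 8))
lhs = jax_general_softmax(x, offset=0.0)
zero = jnp.zeros(x.shape[:-1] + (1,), x.dtype)
rhs = jax.nn.softmax(jnp.concatenate([x, zero], axis=-1))[..., :-1]
assert jnp.allclose(lhs, rhs, atol=1e-6)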
......@@ -108,28 +108,28 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_num_heads, size_t kv_num_heads,
size_t q_max_seqlen, size_t kv_max_seqlen,
size_t qk_head_dim, size_t v_head_dim,
int64_t window_size_left, int64_t window_size_right);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(
bool is_training, DType q_dtype, DType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left,
int64_t window_size_right);
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right);
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right);
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right);
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq,
int64_t window_size_left, int64_t window_size_right);
// GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler);
......
......@@ -142,6 +142,11 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD)
.value("NVTE_THD", NVTE_QKV_Format::NVTE_THD);
pybind11::enum_<NVTE_Softmax_Type>(m, "NVTE_Softmax_Type", pybind11::module_local())
.value("NVTE_VANILLA_SOFTMAX", NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)
.value("NVTE_OFF_BY_ONE_SOFTMAX", NVTE_Softmax_Type::NVTE_OFF_BY_ONE_SOFTMAX)
.value("NVTE_LEARNABLE_SOFTMAX", NVTE_Softmax_Type::NVTE_LEARNABLE_SOFTMAX);
pybind11::enum_<NVTE_Activation_Type>(m, "NVTE_Activation_Type", pybind11::module_local())
.value("GELU", NVTE_Activation_Type::GELU)
.value("GEGLU", NVTE_Activation_Type::GEGLU)
......
......@@ -7,6 +7,7 @@ Wrapper module for Transformer related layers with FP8 support.
from functools import reduce
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType, Optional
import warnings
import numpy as np
import jax.numpy as jnp
......@@ -23,8 +24,9 @@ from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense
from ..layernorm_mlp import layernorm_mlp
from ..activation import activation
from ..softmax import softmax, SoftmaxType
from ..softmax import softmax, SoftmaxFusionType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..attention import AttnSoftmaxType
from ..cpp_extensions import (
is_softmax_kernel_available,
jax_scaled_softmax,
......@@ -171,15 +173,20 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods
----------
scale_factor : float, default = 1.0
Scalar for the input to softmax.
softmax_type : SoftmaxType, default = SoftmaxType.SCALED
softmax_fusion_type : SoftmaxFusionType, default = SoftmaxFusionType.SCALED
Indicates which fused softmax kernel variant to use.
softmax_type : AttnSoftmaxType, default = AttnSoftmaxType.VANILLA_SOFTMAX
Indicates the attention softmax type (vanilla, off-by-one, or learnable).
"""
scale_factor: float = 1.0
softmax_type: SoftmaxType = SoftmaxType.SCALED
softmax_fusion_type: SoftmaxFusionType = SoftmaxFusionType.SCALED
softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX
@nn.compact
def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
def __call__(
self, inputs: Array, mask: Array = None, bias: Array = None, softmax_offset: Array = None
) -> jnp.ndarray:
batch = inputs.shape[0]
heads = inputs.shape[1]
q_seqlen = inputs.shape[2]
......@@ -187,33 +194,52 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods
input_dtype = inputs.dtype
logits = inputs
if softmax_offset is not None:
assert self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX
if self.softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
softmax_offset = 0.0
# use primitives
if is_softmax_kernel_available(
self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype
self.softmax_fusion_type,
self.softmax_type,
batch,
heads,
q_seqlen,
k_seqlen,
input_dtype,
):
if bias is not None:
logits = logits + bias.astype(input_dtype)
mask_ = mask
if self.softmax_type is not SoftmaxType.SCALED_MASKED:
if self.softmax_fusion_type is not SoftmaxFusionType.SCALED_MASKED:
mask_ = None
outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
outputs = softmax(logits, mask_, self.scale_factor, self.softmax_fusion_type)
# use default jax based implementation
else:
warnings.warn(
"Using unfused JAX softmax implementation instead of TE fused primitives. ",
UserWarning,
stacklevel=2,
)
if bias is not None:
logits = logits + bias.astype(input_dtype)
if self.softmax_type is SoftmaxType.SCALED:
outputs = jax_scaled_softmax(logits, self.scale_factor)
elif self.softmax_type is SoftmaxType.SCALED_MASKED:
outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor)
elif self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
outputs = jax_scaled_upper_triang_masked_softmax(logits, self.scale_factor)
if self.softmax_fusion_type is SoftmaxFusionType.SCALED:
outputs = jax_scaled_softmax(logits, self.scale_factor, softmax_offset)
elif self.softmax_fusion_type is SoftmaxFusionType.SCALED_MASKED:
outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor, softmax_offset)
elif self.softmax_fusion_type is SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED:
outputs = jax_scaled_upper_triang_masked_softmax(
logits, self.scale_factor, softmax_offset
)
else:
raise ValueError(
f"Unsupported softmax type: {self.softmax_type}. softmax_type must be [SCALED,"
" SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
f"Unsupported softmax fusion: {self.softmax_fusion_type}. softmax_fusion_type"
" must be [SCALED, SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
)
assert input_dtype == outputs.dtype
return outputs
......
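A hedged usage sketch of the extended Softmax module: for OFF_BY_ONE the module fixes the offset to 0.0 internally, while LEARNABLE expects the caller to pass the (1, h, 1, 1) offset, as _UnfusedDotProductAttention does in the next file. Input shape illustrative:

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax.module import Softmax
from transformer_engine.jax.attention import AttnSoftmaxType
from transformer_engine.jax.softmax import SoftmaxFusionType

logits = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 16, 16), jnp.bfloat16)
module = Softmax(
    scale_factor=0.125,
    softmax_fusion_type=SoftmaxFusionType.SCALED,
    softmax_type=AttnSoftmaxType.OFF_BY_ONE_SOFTMAX,
)
variables = module.init(jax.random.PRNGKey(1), logits)
probs = module.apply(variables, logits)  # takes the unfused JAX path (with a warning)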
......@@ -23,11 +23,17 @@ from jax.ad_checkpoint import checkpoint_name
from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
from ..attention import AttnBiasType, AttnMaskType, QKVLayout, SequenceDescriptor
from ..attention import (
AttnBiasType,
AttnMaskType,
AttnSoftmaxType,
QKVLayout,
SequenceDescriptor,
)
from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type
from ..attention import fused_attn
from ..attention import CPStrategy
from ..softmax import SoftmaxType
from ..softmax import SoftmaxFusionType
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
from ..sharding import with_sharding_constraint_by_logical_axes
......@@ -120,6 +126,7 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = True
window_size: Optional[Tuple[int, int]] = None
softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX
@nn.compact
def __call__(
......@@ -145,6 +152,22 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
input_dtype = query.dtype
# Infer number of attention heads from query shape
# query shape: [..., h, d] where h is num_attention_heads
num_attention_heads = query.shape[-2]
# Initialize softmax_offset for learnable softmax
# Note: OFF_BY_ONE_SOFTMAX is handled internally by the Softmax module
softmax_offset = None
if self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
# For learnable softmax, create a learnable parameter with proper sharding and shape (1, h, 1, 1)
softmax_offset = self.param(
"softmax_offset",
nn.with_logical_partitioning(nn.initializers.zeros, (None, HEAD_AXES, None, None)),
(1, num_attention_heads, 1, 1),
jnp.float32,
)
if self.scale_factor is None:
scale_factor = 1.0 / sqrt(query.shape[-1])
else:
@@ -213,8 +236,8 @@ class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-
            new_mask = jnp.where(original_mask == 0, swa_mask, original_mask)
            return new_mask

        def convert_to_softmax_fusion_type(attn_mask_type, mask):
            """Convert the attn_mask_type to a SoftmaxFusionType."""
            # mask is ignored for no_mask and for causal_mask without sliding window
            if attn_mask_type == AttnMaskType.NO_MASK:
                mask = None
@@ -224,21 +247,23 @@ class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-
                mask = apply_swa_mask(mask)

            # Currently the cuDNN backend only supports SWA for causal/padding_causal; follow this.
            if mask is not None:
                return SoftmaxFusionType.SCALED_MASKED, mask
            if attn_mask_type is AttnMaskType.CAUSAL_MASK:
                return SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED, mask
            if attn_mask_type is AttnMaskType.NO_MASK:
                return SoftmaxFusionType.SCALED, mask
            raise ValueError(
                f"Unsupported {attn_mask_type=}, supported attn_mask_type="
                "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}"
            )
        softmax_fusion_type, mask = convert_to_softmax_fusion_type(self.attn_mask_type, mask)

        attn_weights = Softmax(
            softmax_fusion_type=softmax_fusion_type,
            softmax_type=self.softmax_type,
            scale_factor=fused_scale_factor,
        )(attn_weights, mask, bias, softmax_offset=softmax_offset).astype(input_dtype)

        if is_gqa:
            attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
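The `is_gqa` branch above reshapes grouped-query attention weights back to per-head form. A minimal sketch of the grouping idea — the shapes, names, and einsum here are illustrative, not TE's internals:

```python
import jax.numpy as jnp

b, s, h_q, h_kv, d = 2, 16, 8, 2, 64  # h_q must be a multiple of h_kv
query = jnp.zeros((b, s, h_q, d))
key = jnp.zeros((b, s, h_kv, d))

# Fold query heads into (kv_heads, group): each KV head serves h_q // h_kv queries.
grouped_q = query.reshape(b, s, h_kv, h_q // h_kv, d)
# Attention weights then carry the group axis: [b, h_kv, group, s_q, s_kv] ...
weights = jnp.einsum("bsngd,btnd->bngst", grouped_q, key)
# ... and are flattened back to per-head weights [b, h_q, s_q, s_kv] afterwards.
weights = weights.reshape(b, h_q, s, s)
```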
@@ -279,6 +304,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
    context_parallel_axis: str = ""
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT
    context_checkpoint_name: str = "context"
    softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX

    @nn.compact
    def __call__(
@@ -303,6 +329,17 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
        scale_factor = self.scale_factor
        del self.scale_factor

        num_attention_heads = query.shape[-2]
        softmax_offset = None
        if self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
            # For learnable softmax, create a learnable parameter with proper
            # sharding and shape (1, h, 1, 1).
            softmax_offset = self.param(
                "softmax_offset",
                nn.with_logical_partitioning(nn.initializers.zeros, (None, HEAD_AXES, None, None)),
                (1, num_attention_heads, 1, 1),
                jnp.float32,
            )

        if self.qkv_layout.is_qkvpacked():
            """qkvpacked format, treat
            query: qkvpacked tensor, shape = [..., 3, h, d]
@@ -320,6 +357,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
                attn_mask_type=self.attn_mask_type,
                attn_bias_type=self.attn_bias_type,
                qkv_layout=self.qkv_layout,
                softmax_type=self.softmax_type,
                scaling_factor=scale_factor,
                dropout_probability=self.attention_dropout,
                is_training=not deterministic,
@@ -329,6 +367,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
                context_parallel_axis=self.context_parallel_axis,
                context_parallel_strategy=self.context_parallel_strategy,
                context_checkpoint_name=self.context_checkpoint_name,
                softmax_offset=softmax_offset,
            )
        elif self.qkv_layout.is_kvpacked():
            """kvpacked format, treat
@@ -348,6 +387,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
                attn_mask_type=self.attn_mask_type,
                attn_bias_type=self.attn_bias_type,
                qkv_layout=self.qkv_layout,
                softmax_type=self.softmax_type,
                scaling_factor=scale_factor,
                dropout_probability=self.attention_dropout,
                is_training=not deterministic,
@@ -357,6 +397,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
                context_parallel_axis=self.context_parallel_axis,
                context_parallel_strategy=self.context_parallel_strategy,
                context_checkpoint_name=self.context_checkpoint_name,
                softmax_offset=softmax_offset,
            )
        elif self.qkv_layout.is_separate():
            if self.transpose_batch_sequence:
@@ -371,6 +412,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
                attn_mask_type=self.attn_mask_type,
                attn_bias_type=self.attn_bias_type,
                qkv_layout=self.qkv_layout,
                softmax_type=self.softmax_type,
                scaling_factor=scale_factor,
                dropout_probability=self.attention_dropout,
                is_training=not deterministic,
@@ -380,6 +422,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
                context_parallel_axis=self.context_parallel_axis,
                context_parallel_strategy=self.context_parallel_strategy,
                context_checkpoint_name=self.context_checkpoint_name,
                softmax_offset=softmax_offset,
            )
        else:
            raise ValueError(f"Unsupported {self.qkv_layout=}.")
@@ -514,6 +557,17 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
    context_parallel_axis (str): The name of the context parallel axis.
    context_parallel_strategy (CPStrategy): The strategy of context parallelism. 0: DEFAULT, 1: ALL_GATHER, 2: RING.
    context_checkpoint_name (str): The name of the context checkpoint in the forward pass of fused attention.
    softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
        Softmax variant, as described in the paper
        `Efficient Streaming Language Models with Attention Sinks
        <https://arxiv.org/pdf/2309.17453v3>`_.
        For attention scores S = Q*K^T of shape [b, h, s_q, s_kv]:
        'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
        'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
        'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
        where alpha is a learnable parameter of shape [h].
        The 'off-by-one' and 'learnable' softmax types are also called sink attention
        ('zero sink' and 'learnable sink', respectively).

    Optimization parameters
    -----------------------
@@ -539,6 +593,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
    context_parallel_axis: str = ""
    context_parallel_strategy: str = "DEFAULT"
    context_checkpoint_name: str = "context"
    softmax_type: str = "vanilla"

    @nn.compact
    def __call__(
@@ -595,6 +650,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
        attn_bias_type = AttnBiasType[self.attn_bias_type.upper()]
        attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
        qkv_layout = QKVLayout[self.qkv_layout.upper()]
        softmax_type = AttnSoftmaxType.from_str(self.softmax_type)
        del self.attn_bias_type, self.attn_mask_type, self.qkv_layout

        if attn_bias_type == AttnBiasType.NO_BIAS:
@@ -626,6 +682,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
            softmax_type,
            self.attention_dropout,
            self.num_attention_heads,
            self.num_gqa_groups,
@@ -702,6 +759,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                window_size=self.window_size,
                softmax_type=softmax_type,
            )(
                query,
                key,
@@ -726,6 +784,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
                context_parallel_axis=self.context_parallel_axis,
                context_parallel_strategy=context_parallel_strategy,
                context_checkpoint_name=self.context_checkpoint_name,
                softmax_type=softmax_type,
            )(
                query,
                key,
@@ -1005,6 +1064,17 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
        Deprecated. Please refer to `fuse_qkv_params`.
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. Default value is no sliding window.
    softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
        Softmax variant, as described in the paper
        `Efficient Streaming Language Models with Attention Sinks
        <https://arxiv.org/pdf/2309.17453v3>`_.
        For attention scores S = Q*K^T of shape [b, h, s_q, s_kv]:
        'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
        'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
        'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
        where alpha is a learnable parameter of shape [h].
        The 'off-by-one' and 'learnable' softmax types are also called sink attention
        ('zero sink' and 'learnable sink', respectively).
    """
    head_dim: int
@@ -1036,6 +1106,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
    scaled_query_init: bool = True
    float32_logits: bool = False
    window_size: Optional[Tuple[int, int]] = None
    softmax_type: str = "vanilla"

    # Deprecated parameters
    num_heads: Optional[int] = None
@@ -1440,6 +1511,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
            scale_factor=scale_factor,
            transpose_batch_sequence=self.transpose_batch_sequence,
            window_size=self.window_size,
            softmax_type=self.softmax_type,
        )(*dpa_args, mask, bias, deterministic=deterministic)

        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
@@ -1721,6 +1793,18 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
        Whether to enable sequence parallelism for operations other than dot.
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. Default value is no sliding window.
    softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
        Softmax variant, as described in the paper
        `Efficient Streaming Language Models with Attention Sinks
        <https://arxiv.org/pdf/2309.17453v3>`_.
        For attention scores S = Q*K^T of shape [b, h, s_q, s_kv]:
        'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
        'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
        'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
        where alpha is a learnable parameter of shape [h].
        The 'off-by-one' and 'learnable' softmax types are also called sink attention
        ('zero sink' and 'learnable sink', respectively).
        Only supported by the fused attention backend.

    Optimization parameters
    -----------------------
@@ -1786,6 +1870,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    window_size: Optional[Tuple[int, int]] = None
    softmax_type: str = "vanilla"

    def __post_init__(self):
        if self.mha_kernel_init is None:
@@ -1946,6 +2031,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
            bias_init=self.bias_init,
            name=mha_name,
            window_size=self.window_size,
            softmax_type=self.softmax_type,
        )(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode)

        def hidden_dropout(x, deterministic):
@@ -2024,6 +2110,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
            bias_init=self.bias_init,
            name="encoder_decoder_attention",
            window_size=self.window_size,
            softmax_type=self.softmax_type,
        )(x, encoded, encoder_decoder_mask, deterministic=deterministic)

        y = with_sharding_constraint_by_logical_axes(
......
@@ -12,8 +12,8 @@ import jax.numpy as jnp

from . import cpp_extensions as tex


class SoftmaxFusionType(Enum):
    """SoftmaxFusionType."""

    SCALED = "scaled"
    SCALED_MASKED = "scaled_masked"
@@ -24,27 +24,27 @@ def softmax(
    logits: jnp.ndarray,
    mask: Optional[jnp.ndarray] = None,
    scale_factor: Optional[float] = 1.0,
    softmax_fusion_type: Optional[SoftmaxFusionType] = SoftmaxFusionType.SCALED,
):
    """
    Softmax wrapper
    """
    output = _softmax(logits, mask, scale_factor, softmax_fusion_type)
    return output
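A hypothetical call of the wrapper above — shapes and dtype are illustrative only:

```python
import jax.numpy as jnp

logits = jnp.zeros((2, 16, 128, 128), dtype=jnp.float16)  # [b, h, s_q, s_kv]
probs = softmax(logits, mask=None, scale_factor=0.125)     # defaults to SCALED
```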
@partial(jax.custom_vjp, nondiff_argnums=(2, 3))
def _softmax(logits, mask, scale_factor, softmax_fusion_type):
    output, _ = _softmax_fwd_rule(logits, mask, scale_factor, softmax_fusion_type)
    return output
def _softmax_fwd_rule(logits, mask, scale_factor, softmax_fusion_type):
    if softmax_fusion_type is SoftmaxFusionType.SCALED_MASKED:
        assert mask is not None
        output = tex.scaled_masked_softmax_fwd(logits, mask, scale_factor)
    elif softmax_fusion_type is SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED:
        output = tex.scaled_upper_triang_masked_softmax_fwd(logits, scale_factor)
    else:
        output = tex.scaled_softmax_fwd(logits, scale_factor)
@@ -52,12 +52,12 @@ def _softmax_fwd_rule(logits, mask, scale_factor, softmax_fusion_type):
    return output, (output, logits, mask)
def _softmax_bwd_rule(scale_factor, softmax_fusion_type, ctx, dz):
    (softmax_output, logits, mask) = ctx
    if softmax_fusion_type is SoftmaxFusionType.SCALED_MASKED:
        dgrad = tex.scaled_masked_softmax_bwd(dz, softmax_output, logits, mask, scale_factor)
    elif softmax_fusion_type is SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED:
        dgrad = tex.scaled_upper_triang_masked_softmax_bwd(dz, softmax_output, logits, scale_factor)
    else:
        dgrad = tex.scaled_softmax_bwd(dz, softmax_output, logits, scale_factor)
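The `jax.custom_vjp` wiring above is the standard pattern: `nondiff_argnums=(2, 3)` makes `scale_factor` and `softmax_fusion_type` static, so the bwd rule receives them ahead of the residuals. A toy version of the same pattern — names are illustrative, and the real `defvjp` call for `_softmax` is in lines elided here:

```python
from functools import partial

import jax
import jax.numpy as jnp


# nondiff_argnums marks `scale` (positional index 1) as non-differentiable,
# so it is handed to the bwd rule ahead of the residuals.
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def toy_scaled_softmax(logits, scale):
    output, _ = _toy_fwd(logits, scale)
    return output


def _toy_fwd(logits, scale):
    output = jax.nn.softmax(logits * scale, axis=-1)
    return output, (output,)


def _toy_bwd(scale, residuals, dz):
    (output,) = residuals
    # Softmax VJP: out * (dz - <dz, out>), then the chain rule for the scale.
    dlogits = output * (dz - jnp.sum(dz * output, axis=-1, keepdims=True)) * scale
    return (dlogits,)


toy_scaled_softmax.defvjp(_toy_fwd, _toy_bwd)
grad = jax.grad(lambda x: toy_scaled_softmax(x, 2.0).sum())(jnp.ones((2, 3)))
```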
......
@@ -156,7 +156,9 @@ class FusedScaleMaskSoftmax(nn.Module):
        softmax_in_fp32: bool = True,
    ) -> None:
        super().__init__()
        self.scaled_masked_softmax_fusion_type = bool(
            int(os.getenv("NVTE_MASKED_SOFTMAX_FUSION", "1"))
        )
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
@@ -189,7 +191,7 @@ class FusedScaleMaskSoftmax(nn.Module):
        """Check FusedScaleMaskSoftmax kernel availability based on size"""
        attn_batches = b * np

        if not self.scaled_masked_softmax_fusion_type:
            return False  # user doesn't want to fuse
        if not self.input_in_float16:
            return False  # input must be fp16
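Per the availability check above, the fused path can also be disabled without code changes; a small sketch (set the variable before any TE modules are constructed):

```python
import os

# Force the unfused softmax implementation: the availability check above
# then returns False and the module falls back to plain JAX ops.
os.environ["NVTE_MASKED_SOFTMAX_FUSION"] = "0"
```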
......