Unverified Commit 8e672ff0 authored by Reese Wang, committed by GitHub

[JAX] Refactor fused attention (#711)



* Remove unused headers
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Unify the fused attn workspace size cpp code
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Reduce the skipped cases
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Rename self/cross attention to qkvpacked/kvpacked
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Update attention mask docs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refine the attn mask implementations
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent b855656b
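
For reviewers trying out the renamed entry points (self_fused_attn -> fused_attn_qkvpacked, cross_fused_attn -> fused_attn_kvpacked), here is a minimal usage sketch assembled from the updated test calls in this diff. The function names, positional order (packed tensor, bias, mask, dropout seed) and keyword names come from the tests below; the tensor shapes, the bfloat16 dtype, the NO_BIAS/NO_MASK configuration, and passing None for bias, mask and seed are illustrative assumptions only, and a GPU with a supported fused-attention backend is assumed.

    import jax
    import jax.numpy as jnp
    from transformer_engine.jax.fused_attn import (AttnBiasType, AttnMaskType,
                                                   fused_attn_qkvpacked, fused_attn_kvpacked)

    # Illustrative sizes only: batch, seqlen, heads, head_dim.
    b, s, h, d = 4, 128, 16, 64
    key = jax.random.PRNGKey(0)

    # Packed QKV (BS3HD layout), as exercised by the qkvpacked tests below.
    qkv = jax.random.normal(key, (b, s, 3, h, d), jnp.bfloat16)
    out = fused_attn_qkvpacked(qkv,
                               None,    # bias: None with NO_BIAS
                               None,    # mask: None with NO_MASK (illustrative)
                               None,    # dropout rng seed: None when dropout is disabled
                               attn_bias_type=AttnBiasType.NO_BIAS,
                               attn_mask_type=AttnMaskType.NO_MASK,
                               scaling_factor=d**-0.5,
                               dropout_probability=0.0,
                               is_training=True)

    # Separate query plus packed KV (BSHD_BS2HD layout), as in the kvpacked tests.
    q = jax.random.normal(key, (b, s, h, d), jnp.bfloat16)
    kv = jax.random.normal(key, (b, s, 2, h, d), jnp.bfloat16)
    out_cross = fused_attn_kvpacked(q, kv, None, None, None,
                                    attn_bias_type=AttnBiasType.NO_BIAS,
                                    attn_mask_type=AttnMaskType.NO_MASK,
                                    scaling_factor=d**-0.5,
                                    dropout_probability=0.0,
                                    is_training=True)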
@@ -16,7 +16,7 @@ from distributed_test_base import compare_ops
 from utils import make_causal_mask, make_self_mask
 from transformer_engine.jax import fp8_autocast
 from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
-from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn
+from transformer_engine.jax.fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked
 from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
 DTYPES = [jnp.float16, jnp.bfloat16]
@@ -86,15 +86,15 @@ class TestDistributedSelfAttn:
 def target_func(qkv, bias, mask):
 return jnp.mean(
-self_fused_attn(qkv,
+fused_attn_qkvpacked(qkv,
 bias,
 mask,
 None,
 attn_bias_type=attn_bias_type,
 attn_mask_type=attn_mask_type,
 scaling_factor=scaling_factor,
 dropout_probability=dropout_prob,
 is_training=is_training))
 def ref_func(qkv, bias, mask):
 query, key, value = jnp.split(qkv, [1, 2], axis=-3)
@@ -192,16 +192,16 @@ class TestDistributedCrossAttn:
 def target_func(q, kv, mask):
 return jnp.mean(
-cross_fused_attn(q,
+fused_attn_kvpacked(q,
 kv,
 None,
 mask,
 None,
 attn_bias_type=attn_bias_type,
 attn_mask_type=attn_mask_type,
 scaling_factor=scaling_factor,
 dropout_probability=dropout_prob,
 is_training=is_training))
 def ref_func(query, kv, mask):
 key, value = jnp.split(kv, [1], axis=-3)
...
@@ -2,8 +2,6 @@
 #
 # See LICENSE for license information.
 """Tests for fused attention"""
-import sys
 from enum import Enum
 from dataclasses import dataclass
 from functools import partial
@@ -21,7 +19,7 @@ from jax import value_and_grad, jit
 from jax.typing import ArrayLike, DTypeLike
 from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
-from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn, fused_attn
+from transformer_engine.jax.fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked, fused_attn
 from transformer_engine.jax.cpp_extensions import FusedAttnHelper
 from transformer_engine_jax import NVTE_Fused_Attn_Backend
@@ -144,18 +142,22 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng
 case QKVLayout.BS3HD:
 query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
 qkv = jnp.concatenate((query, key, value), axis=-3)
-return self_fused_attn(qkv, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
+return fused_attn_qkvpacked(qkv, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
 case QKVLayout.BSHD_BS2HD:
 key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
 kv = jnp.concatenate((key, value), axis=-3)
-return cross_fused_attn(query, kv, bias, mask, dropout_rng,
+return fused_attn_kvpacked(query, kv, bias, mask, dropout_rng,
 **kwargs).astype(query.dtype)
 case QKVLayout.BSHD_BSHD_BSHD:
 return fused_attn(query, key, value, bias, mask, dropout_rng,
 **kwargs).astype(query.dtype)
 class BiasShape(Enum):
+"""
+Enum class to represent the different bias shapes used in the fused attention.
+"""
 BIAS_1HSS = '1HSS'
 BIAS_B1SS = 'B1SS'
 BIAS_BHSS = 'BHSS'
@@ -188,17 +190,16 @@ class FusedAttnRunner:
 if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv:
 pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.")
-self.backend = FusedAttnHelper(
-self.dtype, self.dtype, self.qkv_layout.value, self.attn_bias_type.value,
-self.attn_mask_type.value, self.dropout_prob, self.num_heads_q, self.num_heads_kv,
-self.max_seqlen_q, self.max_seqlen_kv, self.head_dim).get_fused_attn_backend()
+self.backend = FusedAttnHelper(self.dtype, self.dtype, self.qkv_layout.value,
+self.attn_bias_type.value, self.attn_mask_type.value,
+self.dropout_prob, self.num_heads_q, self.num_heads_kv,
+self.max_seqlen_q, self.max_seqlen_kv,
+self.head_dim).get_fused_attn_backend()
 if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
 pytest.skip("Unsupported inputs combination or device compute capability.")
-if self.bias_shape != BiasShape.BIAS_1HSS:
-if self.attn_bias_type != AttnBiasType.POST_SCALE_BIAS:
-pytest.skip("B1SS, BHSS and 11SS bias shapes require POST_SCALE_BIAS.")
-elif self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
+if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
+if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
 pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for "
 "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK.")
 elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
@@ -213,7 +214,9 @@ class FusedAttnRunner:
 q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
 k_shape = v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim)
-if self.bias_shape == BiasShape.BIAS_1HSS:
+if self.attn_bias_type == AttnBiasType.NO_BIAS:
+bias_shape = None
+elif self.bias_shape == BiasShape.BIAS_1HSS:
 bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
 elif self.bias_shape == BiasShape.BIAS_B1SS:
 bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
@@ -222,7 +225,7 @@ class FusedAttnRunner:
 elif self.bias_shape == BiasShape.BIAS_11SS:
 bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
 else:
-pytest.xfail("PyTest attempted to use an unrecognized bias layout!")
+pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")
 self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.)
 self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.)
@@ -237,7 +240,7 @@ class FusedAttnRunner:
 cudnn_neg_inf = -2.**27. if self.dtype == jnp.bfloat16 else -2.**15.
 self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
 max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
 seq_id_size = max_id * 5 // 128    # 5 ids per interval of 128 sequences
 seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
 for i in range(1, len(seq_id)):
 self.bias = \
@@ -327,8 +330,8 @@ class FusedAttnRunner:
 **kwargs), arg_nums))
 jitted_reference = jit(
 value_and_grad(
-lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args,
-**kwargs), arg_nums))
+lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
+arg_nums))
 primitive_out, primitive_dgrad = jitted_primitive(*args)
 reference_out, reference_dgrad = jitted_reference(*args)
@@ -361,10 +364,10 @@ class FusedAttnRunner:
 primitive_dbias = jnp.float32(primitive_dgrad[3])
 reference_dbias = jnp.float32(reference_dgrad[3])
-assert_allclose(
-primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
-jnp.zeros_like(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:]),
+assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
+jnp.zeros_like(primitive_dbias[..., self.valid_len_q:,
+self.valid_len_kv:]),
 dtype=self.dtype)
 # dbias padded part
 assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
@@ -376,15 +379,13 @@ class FusedAttnRunner:
 reference_dbias[..., :self.valid_len_q, :self.valid_len_kv],
 dtype=self.dtype)
-@pytest.mark.parametrize('bias_shape', [
-pytest.param(BiasShape.BIAS_1HSS, id='1-H-S-S'),
-pytest.param(BiasShape.BIAS_B1SS, id='B-1-S-S'),
-pytest.param(BiasShape.BIAS_BHSS, id='B-H-S-S'),
-pytest.param(BiasShape.BIAS_11SS, id='1-1-S-S'),
-])
-@pytest.mark.parametrize('attn_bias_type', [
-pytest.param(AttnBiasType.NO_BIAS, id='NO_BIAS'),
-pytest.param(AttnBiasType.POST_SCALE_BIAS, id='POST_SCALE_BIAS'),
+@pytest.mark.parametrize('attn_bias_type, bias_shape', [
+pytest.param(AttnBiasType.NO_BIAS, None, id='NO_BIAS'),
+pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id='POST_SCALE_BIAS-1HSS'),
+pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id='POST_SCALE_BIAS-B1SS'),
+pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id='POST_SCALE_BIAS-BHSS'),
+pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id='POST_SCALE_BIAS-11SS'),
 ])
 @pytest.mark.parametrize('attn_mask_type', [
 pytest.param(AttnMaskType.NO_MASK, id='NO_MASK'),
@@ -399,31 +400,32 @@ class FusedAttnRunner:
 ])
 @pytest.mark.parametrize('dtype', [
 pytest.param(jnp.bfloat16, id="BF16"),
-pytest.param(jnp.float16, id="FP16")
+pytest.param(jnp.float16, id="FP16"),
 ])
-@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d',[
+@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d', [
 pytest.param(32, 128, 128, 16, 16, 64, id='32-128-128-16-16-64-SELF'),
-pytest.param( 4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'),
+pytest.param(4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'),
 pytest.param(32, 512, 128, 16, 16, 64, id='32-512-128-16-16-64-CROSS'),
-pytest.param( 4, 2048, 1024, 12, 12, 64, id='4-2048-1048-12-12-64-CROSS'),
+pytest.param(4, 2048, 1024, 12, 12, 64, id='4-2048-1048-12-12-64-CROSS'),
 pytest.param(32, 128, 128, 16, 8, 64, id='32-128-128-16-8-64-GQA'),
-pytest.param( 4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA')
+pytest.param(4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA'),
 ])
 @pytest.mark.parametrize('dropout_prob', [
 pytest.param(0.0, id="DROP_0.0"),
-pytest.param(0.1, id="DROP_0.1")
-])
-@pytest.mark.parametrize('is_training', [
-pytest.param(True, id='TRAINING'),
-pytest.param(False, id='INFERENCE'),
+pytest.param(0.1, id="DROP_0.1"),
 ])
 class TestFusedAttn:
 """
 Fused attention tester
 """
 @staticmethod
-def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
-dropout_prob, dtype, is_training, qkv_layout, bias_shape):
+@pytest.mark.parametrize('is_training', [
+pytest.param(True, id='TRAINING'),
+pytest.param(False, id='INFERENCE'),
+])
+def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
+dtype, is_training, qkv_layout, bias_shape):
 """
 Test forward with parameterized configs
 """
@@ -432,13 +434,11 @@ class TestFusedAttn:
 runner.test_forward()
 @staticmethod
-def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
-dropout_prob, dtype, is_training, qkv_layout, bias_shape):
+def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
+dtype, qkv_layout, bias_shape):
 """
 Test backward with parameterized configs
 """
-if not is_training:
-pytest.skip("Backward pass does not support inference.")
 runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
 dropout_prob, dtype, True, qkv_layout, bias_shape)
 runner.test_backward()
@@ -449,6 +449,7 @@ class TestDecoderLayer:
 hidden_dropout_dims=(sequence_dim,),
 intermediate_dropout_dims=(sequence_dim,),
 layer_type=TransformerLayerType.DECODER,
+self_attn_mask_type='padding_causal',
 dtype=dtype,
 **te_layer_attrs)
 ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
@@ -497,6 +498,7 @@ class TestDecoderLayer:
 hidden_dropout_dims=(sequence_dim,),
 intermediate_dropout_dims=(sequence_dim,),
 layer_type=TransformerLayerType.DECODER,
+self_attn_mask_type='padding_causal',
 dtype=dtype,
 **te_layer_attrs)
 ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
...
@@ -730,8 +730,13 @@ class TestDotProductAttn(TestLayer):
 def input_getter(self, shape, dtype):
 key = jax.random.PRNGKey(seed=1234)
 q_key, k_key, v_key = jax.random.split(key, 3)
-return list(map(partial(jax.random.normal, shape=shape, dtype=dtype),
-[q_key, k_key, v_key]))
+b, s, *_ = shape
+if self.attrs[DotProductAttnAttr.TRANSPOSE_BS]:
+b, s = s, b
+mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
+return [
+*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]), mask
+]
 def get_layer_name(self):
 return 'dot_product_attn'
@@ -765,6 +770,7 @@ class TestDotProductAttn(TestLayer):
 @pytest.mark.parametrize('dtype', DTYPE)
 @pytest.mark.parametrize('attrs', DotProductAttnAttr.ATTRS)
 def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
+self.attrs = attrs
 praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
 self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@@ -853,9 +859,11 @@ class MultiHeadAttnAttr:
 class TestMultiHeadAttn(TestLayer):
 def input_getter(self, shape, dtype):
-data_key = jax.random.PRNGKey(seed=1234)
-return (jax.random.normal(data_key, shape,
-dtype), jax.random.normal(data_key, shape, dtype))
+key = jax.random.PRNGKey(seed=1234)
+q_key, kv_key = jax.random.split(key, 2)
+s, b, *_ = shape
+mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
+return [*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask]
 def get_layer_name(self):
 return 'multi_head_attn'
@@ -1183,9 +1191,15 @@ class TransformerLayerAttr:
 class TestTransformer(TestLayer):
 def input_getter(self, shape, dtype):
-data_key = jax.random.PRNGKey(seed=1234)
-return (jax.random.normal(data_key, shape,
-dtype), jax.random.normal(data_key, shape, dtype))
+key = jax.random.PRNGKey(seed=1234)
+q_key, kv_key = jax.random.split(key, 2)
+b, s, *_ = shape
+if self.attrs[TransformerLayerAttr.TRANSPOSE_BS]:
+b, s = s, b
+mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
+return [
+*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask, mask
+]
 def get_layer_name(self):
 return 'transformerlayer'
@@ -1277,6 +1291,7 @@ class TestTransformer(TestLayer):
 @pytest.mark.parametrize('dtype', DTYPE)
 @pytest.mark.parametrize('attrs', TransformerLayerAttr.ATTRS)
 def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
+self.attrs = attrs
 praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
 self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@@ -1292,7 +1307,7 @@ class TestTransformer(TestLayer):
 fp8_format,
 rtol=1e-05,
 atol=1e-08):
+self.attrs = attrs
 ds = DelayedScaling(fp8_format=fp8_format)
 with fp8_autocast(enabled=True, fp8_recipe=ds):
 praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
...
@@ -1368,14 +1368,13 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
 @staticmethod
 def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
 return ScaledSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
-scale_factor, mesh, arg_infos, result_infos
-)
+scale_factor, mesh, arg_infos, result_infos)
 @staticmethod
 def partition(scale_factor, mesh, arg_infos, result_infos):
-return ScaledSoftmaxFwdPrimitive.forward_partition(
-ScaledSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
-)
+return ScaledSoftmaxFwdPrimitive.forward_partition(ScaledSoftmaxFwdPrimitive.impl,
+scale_factor, mesh, arg_infos,
+result_infos)
 register_primitive(ScaledSoftmaxFwdPrimitive)
@@ -1444,14 +1443,13 @@ class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
 @staticmethod
 def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
 return ScaledSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
-scale_factor, mesh, arg_infos, result_infos
-)
+scale_factor, mesh, arg_infos, result_infos)
 @staticmethod
 def partition(scale_factor, mesh, arg_infos, result_infos):
-return ScaledSoftmaxBwdPrimitive.backward_partition(
-ScaledSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
-)
+return ScaledSoftmaxBwdPrimitive.backward_partition(ScaledSoftmaxBwdPrimitive.impl,
+scale_factor, mesh, arg_infos,
+result_infos)
 register_primitive(ScaledSoftmaxBwdPrimitive)
@@ -1581,14 +1579,12 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
 @staticmethod
 def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
 return ScaledMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
-scale_factor, mesh, arg_infos,result_infos
-)
+scale_factor, mesh, arg_infos, result_infos)
 @staticmethod
 def partition(scale_factor, mesh, arg_infos, result_infos):
 return ScaledMaskedSoftmaxFwdPrimitive.backward_partition(
-ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
-)
+ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos)
 register_primitive(ScaledMaskedSoftmaxFwdPrimitive)
@@ -1660,14 +1656,12 @@ class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
 @staticmethod
 def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
 return ScaledMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
-scale_factor, mesh, arg_infos, result_infos
-)
+scale_factor, mesh, arg_infos, result_infos)
 @staticmethod
 def partition(scale_factor, mesh, arg_infos, result_infos):
 return ScaledMaskedSoftmaxBwdPrimitive.backward_partition(
-ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
-)
+ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos)
 register_primitive(ScaledMaskedSoftmaxBwdPrimitive)
@@ -1749,15 +1743,13 @@ class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
 @staticmethod
 def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
 return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
-scale_factor, mesh, arg_infos, result_infos
-)
+scale_factor, mesh, arg_infos, result_infos)
 @staticmethod
 def partition(scale_factor, mesh, arg_infos, result_infos):
 return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_partition(
-ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh,
-arg_infos, result_infos
-)
+ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos,
+result_infos)
 register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive)
@@ -1829,15 +1821,13 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
 @staticmethod
 def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
 return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
-scale_factor, mesh, arg_infos, result_infos
-)
+scale_factor, mesh, arg_infos, result_infos)
 @staticmethod
 def partition(scale_factor, mesh, arg_infos, result_infos):
 return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_partition(
-ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh,
-arg_infos, result_infos
-)
+ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos,
+result_infos)
 register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)
@@ -1859,16 +1849,16 @@ class FusedAttnHelper:
 Helper for the fused attention backend
 """
-q_type: jnp.dtype
-kv_type: jnp.dtype
+q_dtype: jnp.dtype
+kv_dtype: jnp.dtype
 qkv_layout: NVTE_QKV_Layout
 attn_bias_type: NVTE_Bias_Type
 attn_mask_type: NVTE_Mask_Type
 dropout_probability: float
-num_heads_q: int
-num_heads_kv: int
-max_seqlen_q: int
-max_seqlen_kv: int
+q_num_heads: int
+kv_num_heads: int
+q_max_seqlen: int
+kv_max_seqlen: int
 head_dim: int
 def is_fused_attn_kernel_available(self):
@@ -1878,11 +1868,38 @@ class FusedAttnHelper:
 def get_fused_attn_backend(self):
 """Get the fused attention kernel backend"""
 return transformer_engine_jax.get_fused_attn_backend(
-jax_dtype_to_te_dtype(self.q_type), jax_dtype_to_te_dtype(self.kv_type),
+jax_dtype_to_te_dtype(self.q_dtype), jax_dtype_to_te_dtype(self.kv_dtype),
 self.qkv_layout, self.attn_bias_type, self.attn_mask_type, self.dropout_probability,
-self.num_heads_q, self.num_heads_kv, self.max_seqlen_q, self.max_seqlen_kv,
+self.q_num_heads, self.kv_num_heads, self.q_max_seqlen, self.kv_max_seqlen,
 self.head_dim)
+@staticmethod
+def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout):
+"""Parse qkv aval"""
+match qkv_layout:
+case NVTE_QKV_Layout.NVTE_BS3HD:
+*q_batch_shape, q_max_seqlen, nqkv, attn_heads, q_head_dim = q_aval.shape
+kv_batch_shape = q_batch_shape
+kv_max_seqlen = q_max_seqlen
+num_gqa_groups = attn_heads
+kv_head_dim = q_head_dim
+assert nqkv == 3
+case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
+*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
+*kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = k_aval.shape
+assert nkv == 2
+case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
+*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
+*kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape
+assert k_aval.shape == v_aval.shape
+case _:
+raise ValueError(f"Unexpected {qkv_layout=}")
+assert q_batch_shape == kv_batch_shape
+assert q_head_dim == kv_head_dim
+assert q_aval.dtype == k_aval.dtype == v_aval.dtype
+return (q_batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, q_head_dim)
 @dataclass(frozen=True)
 class _FusedAttnRNGStateChecker:
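
As a reading aid for the hunk above, the shapes implied by the unpacking in FusedAttnHelper.parse_qkv_aval can be summarized as follows; the symbol names (batch, s_q, s_kv, h, h_kv, d) are illustrative only and not taken from the commit.

    # Informal shape summary per QKV layout (sketch, derived from the added parse_qkv_aval):
    #   NVTE_BS3HD:          q.shape == (*batch, s_q, 3, h, d);  k and v are not inspected
    #   NVTE_BSHD_BS2HD:     q.shape == (*batch, s_q, h, d);     k.shape == (*batch, s_kv, 2, h_kv, d)
    #   NVTE_BSHD_BSHD_BSHD: q.shape == (*batch, s_q, h, d);     k.shape == v.shape == (*batch, s_kv, h_kv, d)
    # In every case it returns (batch, s_q, s_kv, h, h_kv, d).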
@@ -1933,46 +1950,50 @@ def generate_cu_seqlen(actual_seqlen):
 return cu_seqlen
-class SelfFusedAttnFwdPrimitive(BasePrimitive):
+class FusedAttnFwdPrimitive(BasePrimitive):
 """
-Self Fused Attention Forward Primitive
+Fused Attention Forward Primitive
 """
-name = "te_self_fused_attn_forward"
+name = "te_fused_attn_forward"
 multiple_results = True
-impl_static_args = (4, 5, 6, 7, 8)
+impl_static_args = (7, 8, 9, 10, 11, 12)
 inner_primitive = None
 outer_primitive = None
 @staticmethod
-def abstract(qkv_aval, bias_aval, seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type,
-attn_mask_type, scaling_factor, dropout_probability, is_training):
+def abstract(q_aval, k_aval, v_aval, bias_aval, q_seqlen_or_cu_seqlen_aval,
+kv_seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, attn_mask_type,
+qkv_layout, scaling_factor, dropout_probability, is_training):
 """
-Self fused attention fwd inner primitive abstract
+Fused attention fwd abstract
 """
-# outer_primitve is squeezed_mask, inner_primitive is cu_seqlen
-del seqlen_or_cu_seqlen_aval
-qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
-*input_batch_shape, max_seqlen, nqkv, attn_heads, head_dim = qkv_aval.shape
-assert nqkv == 3
-assert qkv_aval.dtype == bias_aval.dtype
+q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
+k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
+v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
+bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
+assert q_dtype == k_dtype == v_dtype == bias_dtype
+assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
+batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = \
+FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)
-output_shape = (*input_batch_shape, max_seqlen, attn_heads, head_dim)
-out_aval = qkv_aval.update(shape=output_shape, dtype=qkv_dtype)
+output_shape = (*batch_shape, q_max_seqlen, attn_heads, head_dim)
+out_aval = q_aval.update(shape=output_shape, dtype=q_dtype)
 # backend determines the softmax buffer shape/dtype
-backend = FusedAttnHelper(qkv_dtype, qkv_dtype, NVTE_QKV_Layout.NVTE_BS3HD, attn_bias_type,
-attn_mask_type, dropout_probability, attn_heads, attn_heads,
-max_seqlen, max_seqlen, head_dim).get_fused_attn_backend()
+backend = FusedAttnHelper(q_dtype, k_dtype, qkv_layout, attn_bias_type, attn_mask_type,
+dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen,
+kv_max_seqlen, head_dim).get_fused_attn_backend()
 if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
-softmax_shape = (*input_batch_shape, attn_heads, max_seqlen, max_seqlen)
-softmax_dtype = qkv_dtype
+softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen)
+softmax_dtype = q_dtype
 elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
-softmax_shape = (*input_batch_shape, attn_heads, max_seqlen, 1)
+softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
 softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
 else:
 raise ValueError(f'Unsupported {backend=}')
-softmax_aux_aval = qkv_aval.update(shape=softmax_shape, dtype=softmax_dtype)
+softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype)
 # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with
 # 32-bit unsigned int to get the buffer size we need in the C++ kernel
@@ -1990,32 +2011,32 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
 # do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
 # prepare for the active fused-attn backend
-input_batch = reduce(operator.mul, input_batch_shape)
-wkspace_info = transformer_engine_jax.get_self_fused_attn_fwd_workspace_sizes(
-input_batch, bias_batch, max_seqlen, attn_heads, bias_heads, head_dim,
-scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
-jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
-wkspace_aval = qkv_aval.update(shape=wkspace_info[0],
+input_batch = reduce(operator.mul, batch_shape)
+wkspace_info = transformer_engine_jax.get_fused_attn_fwd_workspace_sizes(
+input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups,
+bias_heads, head_dim, scaling_factor, dropout_probability, attn_bias_type,
+attn_mask_type, qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), is_training)
+wkspace_aval = q_aval.update(shape=wkspace_info[0],
 dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
 return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval
 @staticmethod
 def outer_abstract(*args, **kwargs):
 """
-Self fused attention fwd outer primitive abstract
+Fused attention fwd outer primitive abstract
 """
 out_aval, softmax_aux_aval, rng_state_aval, _ = \
-SelfFusedAttnFwdPrimitive.abstract(*args, **kwargs)
+FusedAttnFwdPrimitive.abstract(*args, **kwargs)
 return out_aval, softmax_aux_aval, rng_state_aval
 @staticmethod
-def lowering(ctx, qkv, bias, cu_seqlen, seed, *, attn_bias_type, attn_mask_type, scaling_factor,
-dropout_probability, is_training):
+def lowering(ctx, q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type,
+attn_mask_type, qkv_layout, scaling_factor, dropout_probability, is_training):
 """
-Self fused attention fwd lowering rules
+Fused attention fwd lowering rules
 """
-operands = [qkv, bias, cu_seqlen, seed]
+operands = [q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed]
 operand_shapes = map(lambda x: x.type.shape, operands)
 out_types = [
 ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
@@ -2023,9 +2044,12 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
 ]
 args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
-qkv_aval, bias_aval, *_ = ctx.avals_in
-*input_batch_shape, max_seqlen, _, attn_heads, head_dim = qkv_aval.shape
-input_batch = reduce(operator.mul, input_batch_shape)
+q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
+batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = \
+FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)
+input_batch = reduce(operator.mul, batch_shape)
 if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
 bias_batch = bias_heads = 0
@@ -2036,137 +2060,137 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
 wkspace_aval = ctx.avals_out[-1]
 opaque = transformer_engine_jax.pack_fused_attn_descriptor(
-input_batch, bias_batch, max_seqlen, max_seqlen,
-attn_heads, attn_heads, bias_heads, head_dim, wkspace_aval.size,
-scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
-jax_dtype_to_te_dtype(qkv_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
-is_training)
+input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups,
+bias_heads, head_dim, wkspace_aval.size, scaling_factor, dropout_probability,
+attn_bias_type, attn_mask_type, qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype),
+jax_dtype_to_te_dtype(wkspace_aval.dtype), is_training)
-out = custom_caller(SelfFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
+out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
 return out
 @staticmethod
-def impl(qkv, bias, seqlen, seed, attn_bias_type, attn_mask_type, scaling_factor,
-dropout_probability, is_training):
-assert SelfFusedAttnFwdPrimitive.inner_primitive is not None
-cu_seqlen = generate_cu_seqlen(seqlen)
+def impl(q, k, v, bias, q_seqlen, kv_seqlen, seed, attn_bias_type, attn_mask_type, qkv_layout,
+scaling_factor, dropout_probability, is_training):
+assert FusedAttnFwdPrimitive.inner_primitive is not None
+q_cu_seqlen = generate_cu_seqlen(q_seqlen)
+kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
-output, softmax_aux, rng_state, _ = SelfFusedAttnFwdPrimitive.inner_primitive.bind(
-qkv,
+output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind(
+q,
+k,
+v,
 bias,
-cu_seqlen,
+q_cu_seqlen,
+kv_cu_seqlen,
 seed,
 attn_bias_type=attn_bias_type,
 attn_mask_type=attn_mask_type,
+qkv_layout=qkv_layout,
 scaling_factor=scaling_factor,
 dropout_probability=dropout_probability,
 is_training=is_training)
 return output, softmax_aux, rng_state
 @staticmethod
-def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
-dropout_probability, is_training):
+def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, qkv_layout,
+scaling_factor, dropout_probability, is_training):
 _check_valid_batch_dims(batch_dims)
-assert SelfFusedAttnFwdPrimitive.outer_primitive is not None
-qkv_bdim, _, _, seed_bdim = batch_dims
-out_bdims = qkv_bdim, qkv_bdim, seed_bdim
-return SelfFusedAttnFwdPrimitive.outer_primitive.bind(
-*batched_args,
-attn_bias_type=attn_bias_type,
-attn_mask_type=attn_mask_type,
+assert FusedAttnFwdPrimitive.outer_primitive is not None
+q_bdim, *_, seed_bdim = batch_dims
+out_bdims = q_bdim, q_bdim, seed_bdim
+return FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args,
+attn_bias_type=attn_bias_type,
+attn_mask_type=attn_mask_type,
+qkv_layout=qkv_layout,
 scaling_factor=scaling_factor,
 dropout_probability=dropout_probability,
 is_training=is_training), out_bdims
 @staticmethod
-def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
+def infer_sharding_from_operands(attn_bias_type, attn_mask_type, qkv_layout, scaling_factor,
 dropout_probability, is_training, mesh, arg_infos,
 result_infos):
 del attn_bias_type, attn_mask_type, scaling_factor
 del dropout_probability, is_training, result_infos
-x_spec = get_padded_spec(arg_infos[0])    # (...batch, seqlen, 3, head, hidden)
-out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-3], *x_spec[-2:]))
-softmax_aux_sharding = NamedSharding(
-mesh, PartitionSpec(*x_spec[:-4], x_spec[-2], x_spec[-4], None))
+q_spec = get_padded_spec(arg_infos[0])
+k_spec = get_padded_spec(arg_infos[1])
+match qkv_layout:
+case NVTE_QKV_Layout.NVTE_BS3HD:
+# q_spec = (...batch, q_seqlen, head, hidden)
+out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:]))
+softmax_aux_sharding = NamedSharding(
+mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None))
+case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
+# q_spec = (...batch, q_seqlen, head, hidden)
+# k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden)
+out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
+softmax_aux_sharding = NamedSharding(
+mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-4]))
+case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
+# q_spec = (...batch, q_seqlen, head, hidden)
+# k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden)
+out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
+softmax_aux_sharding = NamedSharding(
+mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3]))
+case _:
+raise ValueError(f"Unsupported {qkv_layout=}")
 rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
 return (out_sharding, softmax_aux_sharding, rng_state_sharding)
 @staticmethod
-def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
-mesh, arg_infos, result_infos):
-del result_infos
-x_spec = get_padded_spec(arg_infos[0])    # (...batch, seqlen, 3, head, hidden)
-out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-3], *x_spec[-2:]))
-softmax_aux_sharding = NamedSharding(
-mesh, PartitionSpec(*x_spec[:-4], x_spec[-2], x_spec[-4], None))
-rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
-arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [rng_state_sharding])
+def partition(attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability,
+is_training, mesh, arg_infos, result_infos):
+out_sharding = result_infos[0].sharding
+softmax_aux_sharding = result_infos[1].sharding
+rng_state_sharding = seed_sharding = NamedSharding(mesh,
+PartitionSpec(get_all_mesh_axes(), None))
+arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding])
 out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
-impl = partial(SelfFusedAttnFwdPrimitive.impl,
+impl = partial(FusedAttnFwdPrimitive.impl,
 attn_bias_type=attn_bias_type,
 attn_mask_type=attn_mask_type,
+qkv_layout=qkv_layout,
 scaling_factor=scaling_factor,
 dropout_probability=dropout_probability,
 is_training=is_training)
 return mesh, impl, out_shardings, arg_shardings
-register_primitive(SelfFusedAttnFwdPrimitive)
+register_primitive(FusedAttnFwdPrimitive)
-def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray | None, seqlen: jnp.ndarray,
-seed: jnp.ndarray | None, attn_bias_type: NVTE_Bias_Type,
-attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
-dropout_probability: float, is_training: bool):
-"""
-Wrapper for TE self fused attention fwd
-Return BMM1 -> (PreScaleBias) -> Scale -> (PostScaleBias) -> Softmax -> (Dropout) -> BMM2
-"""
-checker = _FusedAttnRNGStateChecker()
-seed = checker.check_seed(seed, dropout_probability, is_training)
-if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
-assert bias is None
-bias = jnp.zeros(0, dtype=qkv.dtype)
-return SelfFusedAttnFwdPrimitive.outer_primitive.bind(qkv,
-bias,
-seqlen,
-seed,
-attn_bias_type=attn_bias_type,
-attn_mask_type=attn_mask_type,
-scaling_factor=scaling_factor,
-dropout_probability=dropout_probability,
-is_training=is_training)
-class SelfFusedAttnBwdPrimitive(BasePrimitive):
+class FusedAttnBwdPrimitive(BasePrimitive):
 """
-Self Fused Attention Backward Primitive
+Fused Attention Backward Primitive
 """
-name = "te_self_fused_attn_backward"
+name = "te_fused_attn_backward"
 multiple_results = True
-impl_static_args = (7, 8, 9, 10, 11)
+impl_static_args = (10, 11, 12, 13, 14, 15)
 inner_primitive = None
 outer_primitive = None
 @staticmethod
-def abstract(qkv_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval, doutput_aval,
-seqlen_or_cu_seqlen_aval, *, attn_bias_type, attn_mask_type, scaling_factor,
-dropout_probability, is_training):
+def abstract(q_aval, k_aval, v_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval,
+doutput_aval, q_cu_seqlen_aval, kv_cu_seqlen_aval, *, attn_bias_type,
+attn_mask_type, qkv_layout, scaling_factor, dropout_probability, is_training):
 """
-Self fused attention bwd abstract
+Fused attention bwd abstract
 """
-del softmax_aux_aval, rng_state_aval, seqlen_or_cu_seqlen_aval
-assert qkv_aval.dtype == bias_aval.dtype == output_aval.dtype == doutput_aval.dtype
-*input_batch_shape, max_seqlen, nqkv, attn_heads, head_dim = qkv_aval.shape
-assert nqkv == 3
-qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
+del softmax_aux_aval, rng_state_aval, output_aval
+q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
+k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
+v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
 bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
+doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype)
+assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype
+assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype
+batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = \
+FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)
 if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
 bias_batch = bias_heads = 0
@@ -2174,46 +2198,55 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
 *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
 bias_batch = reduce(operator.mul, bias_batch_shape)
-input_batch = reduce(operator.mul, input_batch_shape)
+input_batch = reduce(operator.mul, batch_shape)
 wkspace_shape, wkspace_dtype = \
-transformer_engine_jax.get_self_fused_attn_bwd_workspace_sizes(
-input_batch, bias_batch, max_seqlen, attn_heads, bias_heads, head_dim,
-scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
-jax_dtype_to_te_dtype(qkv_aval.dtype), is_training
-)
+transformer_engine_jax.get_fused_attn_bwd_workspace_sizes(
+input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups,
+bias_heads, head_dim, scaling_factor, dropout_probability, attn_bias_type,
+attn_mask_type, qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), is_training)
-dqkv_aval = qkv_aval.update(shape=qkv_aval.shape, dtype=qkv_dtype)
+dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
+dk_aval = k_aval.update(shape=k_aval.shape, dtype=k_dtype)
+dv_aval = v_aval.update(shape=v_aval.shape, dtype=v_dtype)
 dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
-wkspace_aval = qkv_aval.update(shape=wkspace_shape,
+wkspace_aval = q_aval.update(shape=wkspace_shape,
 dtype=te_dtype_to_jax_dtype(wkspace_dtype))
-return dqkv_aval, dbias_aval, wkspace_aval
+return dq_aval, dk_aval, dv_aval, dbias_aval, wkspace_aval
 @staticmethod
 def outer_abstract(*args, **kwargs):
 """
-Self fused attention bwd outer primitive abstract
+Fused attention fwd outer primitive abstract
 """
-dqkv_aval, dbias_aval, _ = SelfFusedAttnBwdPrimitive.abstract(*args, **kwargs)
-return dqkv_aval, dbias_aval
+dq_aval, dk_aval, dv_aval, dbias_aval, _ = \
+FusedAttnBwdPrimitive.abstract(*args, **kwargs)
+return dq_aval, dk_aval, dv_aval, dbias_aval
 @staticmethod
-def lowering(ctx, qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen, *,
-attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training):
+def lowering(ctx, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen,
+kv_cu_seqlen, *, attn_bias_type, attn_mask_type, qkv_layout, scaling_factor,
+dropout_probability, is_training):
 """
-Self fused attention bwd lowering rules
+Fused attention bwd lowering rules
 """
-operands = [qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen]
+operands = [
+q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen
+]
 operand_shapes = map(lambda x: x.type.shape, operands)
 out_types = [
 ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
 for output in ctx.avals_out
 ]
 args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
-qkv_aval, bias_aval, *_ = ctx.avals_in
-*input_batch_shape, max_seqlen, _, attn_heads, head_dim = qkv_aval.shape
-input_batch = reduce(operator.mul, input_batch_shape)
+q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
+batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = \
+FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)
+input_batch = reduce(operator.mul, batch_shape)
 if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
 bias_batch = bias_heads = 0
@@ -2224,780 +2257,245 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
 wkspace_aval = ctx.avals_out[-1]
 opaque = transformer_engine_jax.pack_fused_attn_descriptor(
-input_batch, bias_batch, max_seqlen, max_seqlen,
-attn_heads, attn_heads, bias_heads, head_dim, wkspace_aval.size,
-scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
-jax_dtype_to_te_dtype(qkv_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
-is_training)
+input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups,
+bias_heads, head_dim, wkspace_aval.size, scaling_factor, dropout_probability,
+attn_bias_type, attn_mask_type, qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype),
+jax_dtype_to_te_dtype(wkspace_aval.dtype), is_training)
-out = custom_caller(SelfFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
+out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
 return out
 @staticmethod
-def impl(qkv, bias, softmax_aux, rng_state, output, doutput, seqlen, attn_bias_type,
-attn_mask_type, scaling_factor, dropout_probability, is_training):
-assert SelfFusedAttnBwdPrimitive.inner_primitive is not None
-cu_seqlen = generate_cu_seqlen(seqlen)
+def impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen,
+attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability,
+is_training):
+assert FusedAttnBwdPrimitive.inner_primitive is not None
+q_cu_seqlen = generate_cu_seqlen(q_seqlen)
+kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
-dqkv, dbias, _ = SelfFusedAttnBwdPrimitive.inner_primitive.bind(
-qkv,
+dq, dk, dv, dbias, _ = FusedAttnBwdPrimitive.inner_primitive.bind(
+q,
+k,
+v,
 bias,
 softmax_aux,
 rng_state,
 output,
 doutput,
-cu_seqlen,
+q_cu_seqlen,
+kv_cu_seqlen,
 attn_bias_type=attn_bias_type,
 attn_mask_type=attn_mask_type,
+qkv_layout=qkv_layout,
 scaling_factor=scaling_factor,
 dropout_probability=dropout_probability,
 is_training=is_training)
-return dqkv, dbias
+return dq, dk, dv, dbias
 @staticmethod
-def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
-dropout_probability, is_training):
+def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, qkv_layout,
+scaling_factor, dropout_probability, is_training):
 _check_valid_batch_dims(batch_dims)
-assert SelfFusedAttnBwdPrimitive.outer_primitive is not None
-qkv_bdim, *_ = batch_dims
-out_bdims = qkv_bdim, qkv_bdim
-return SelfFusedAttnBwdPrimitive.outer_primitive.bind(
-*batched_args,
-attn_bias_type=attn_bias_type,
-attn_mask_type=attn_mask_type,
+assert FusedAttnBwdPrimitive.outer_primitive is not None
+q_bdim, k_bdim, v_bdim, *_ = batch_dims
+out_bdims = q_bdim, k_bdim, v_bdim, q_bdim
+return FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args,
+attn_bias_type=attn_bias_type,
+attn_mask_type=attn_mask_type,
+qkv_layout=qkv_layout,
 scaling_factor=scaling_factor,
 dropout_probability=dropout_probability,
 is_training=is_training), out_bdims
 @staticmethod
-def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
+def infer_sharding_from_operands(attn_bias_type, attn_mask_type, qkv_layout, scaling_factor,
 dropout_probability, is_training, mesh, arg_infos,
 result_infos):
-del attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
-del is_training, result_infos
-x_spec = get_padded_spec(arg_infos[0])
-bias_spec = get_padded_spec(arg_infos[1])
-dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
+del attn_bias_type, attn_mask_type, qkv_layout, scaling_factor
+del dropout_probability, is_training, result_infos
+q_spec = get_padded_spec(arg_infos[0])
+k_spec = get_padded_spec(arg_infos[1])
+v_spec = get_padded_spec(arg_infos[2])
+bias_spec = get_padded_spec(arg_infos[3])
+dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
+dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
+dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
 dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
-return (dx_sharding, dbias_sharding)
+return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
 @staticmethod
-def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
-mesh, arg_infos, result_infos):
+def partition(attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability,
+is_training, mesh, arg_infos, result_infos):
 del result_infos
-x_spec = get_padded_spec(arg_infos[0])
-bias_spec = get_padded_spec(arg_infos[1])
-dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
+q_spec = get_padded_spec(arg_infos[0])
+k_spec = get_padded_spec(arg_infos[1])
+v_spec = get_padded_spec(arg_infos[2])
+bias_spec = get_padded_spec(arg_infos[3])
+dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
+dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
+dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
 dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
 arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
-out_shardings = (dx_sharding, dbias_sharding)
+out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
-def sharded_impl(qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen):
-local_dx, local_dbias = SelfFusedAttnBwdPrimitive.impl(
-qkv,
+def sharded_impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen,
+kv_cu_seqlen):
+local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl(
+q,
+k,
+v,
 bias,
 softmax_aux,
 rng_state,
 output,
 doutput,
-cu_seqlen,
+q_cu_seqlen,
+kv_cu_seqlen,
 attn_bias_type=attn_bias_type,
 attn_mask_type=attn_mask_type,
+qkv_layout=qkv_layout,
 scaling_factor=scaling_factor,
 dropout_probability=dropout_probability,
 is_training=is_training)
 global_dbias = local_dbias
 if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
return local_dx, global_dbias return local_dq, local_dk, local_dv, global_dbias
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(SelfFusedAttnBwdPrimitive) register_primitive(FusedAttnBwdPrimitive)
def self_fused_attn_bwd(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: jnp.ndarray, def fused_attn_fwd_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray, seqlen: jnp.ndarray,
rng_state: jnp.ndarray, output: jnp.ndarray, doutput: jnp.ndarray, seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
seqlen: jnp.ndarray, attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float, dropout_probability: float, is_training: bool):
dropout_probability: float, is_training: bool):
""" """
Wrapper for TE self fused attention bwd Wrapper for TE fused attention fwd with packed qkv input
Return the gradients of self fused attention with packed qkv input Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
""" """
checker = _FusedAttnRNGStateChecker()
seed = checker.check_seed(seed, dropout_probability, is_training)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None assert bias is None
bias = jnp.zeros(0, dtype=qkv.dtype) bias = jnp.zeros(0, dtype=qkv.dtype)
return SelfFusedAttnBwdPrimitive.outer_primitive.bind(qkv, _not_used = jnp.zeros(0, qkv.dtype)
bias, return FusedAttnFwdPrimitive.outer_primitive.bind(qkv,
softmax_aux, _not_used,
rng_state, _not_used,
output, bias,
doutput, seqlen,
seqlen, seqlen,
attn_bias_type=attn_bias_type, seed,
attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type,
scaling_factor=scaling_factor, attn_mask_type=attn_mask_type,
dropout_probability=dropout_probability, qkv_layout=NVTE_QKV_Layout.NVTE_BS3HD,
is_training=is_training) scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
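For readers skimming the diff, the following is a minimal pure-JAX sketch of the computation that the qkvpacked forward wrapper hands to the fused kernel (BMM1 -> bias -> scale/mask softmax -> dropout -> BMM2). The function name, shapes, and mask convention are illustrative assumptions, not the cuDNN-backed implementation or the exact transformer_engine.jax signature.

import jax
import jax.numpy as jnp

def reference_attn_qkvpacked(qkv, bias, mask, scaling_factor,
                             dropout_rng=None, dropout_probability=0.0):
    """Pure-JAX reference of BMM1 -> (bias) -> scale/mask softmax -> (dropout) -> BMM2.

    qkv is assumed to follow the packed BS3HD-style layout:
    (batch, seqlen, 3, heads, head_dim). `mask` is assumed to hold True/1 at
    positions that should be dropped from attention.
    """
    q, k, v = jnp.split(qkv, 3, axis=-3)                     # each: (b, s, 1, h, d)
    q, k, v = map(lambda x: jnp.squeeze(x, axis=-3), (q, k, v))
    logits = jnp.einsum('bqhd,bkhd->bhqk', q, k) * scaling_factor
    if bias is not None:
        logits = logits + bias                                # bias: (1 or b, h, q, k)
    if mask is not None:
        logits = jnp.where(mask, jnp.finfo(logits.dtype).min, logits)
    probs = jax.nn.softmax(logits, axis=-1)
    if dropout_probability > 0.0 and dropout_rng is not None:
        keep = jax.random.bernoulli(dropout_rng, 1.0 - dropout_probability, probs.shape)
        probs = probs * keep / (1.0 - dropout_probability)
    return jnp.einsum('bhqk,bkhd->bqhd', probs, v)

# Hypothetical shapes for a quick smoke test.
b, s, h, d = 2, 128, 8, 64
qkv = jnp.zeros((b, s, 3, h, d), jnp.bfloat16)
out = reference_attn_qkvpacked(qkv, None, None, scaling_factor=1.0 / jnp.sqrt(d))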
class CrossFusedAttnFwdPrimitive(BasePrimitive): def fused_attn_bwd_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: jnp.ndarray,
rng_state: jnp.ndarray, output: jnp.ndarray, doutput: jnp.ndarray,
seqlen: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool):
"""
Wrapper for TE fused attention bwd with packed qkv input
Return the gradients of fused attention with packed qkv input
""" """
Cross Fused Attention Forward Primitive if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=qkv.dtype)
dummy_input = jnp.zeros(0, dtype=qkv.dtype)
dqkv, *_, dbias = FusedAttnBwdPrimitive.outer_primitive.bind(
qkv,
dummy_input,
dummy_input,
bias,
softmax_aux,
rng_state,
output,
doutput,
seqlen,
seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=NVTE_QKV_Layout.NVTE_BS3HD,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return dqkv, dbias
def fused_attn_fwd_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray, seed: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
scaling_factor: float, dropout_probability: float, is_training: bool):
"""
Wrapper for TE fused attention fwd with kvpacked inputs
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
""" """
name = "te_cross_fused_attn_forward" checker = _FusedAttnRNGStateChecker()
multiple_results = True seed = checker.check_seed(seed, dropout_probability, is_training)
impl_static_args = (6, 7, 8, 9, 10)
inner_primitive = None
outer_primitive = None
@staticmethod if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
def abstract(q_aval, kv_aval, bias_aval, q_seqlen_or_cu_seqlen_aval, assert bias is None
kv_seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, attn_mask_type, bias = jnp.zeros(0, dtype=q.dtype)
scaling_factor, dropout_probability, is_training):
"""
Cross fused attention fwd abstract
"""
q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype)
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
assert q_dtype == kv_dtype == bias_dtype
assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape return FusedAttnFwdPrimitive.outer_primitive.bind(q,
*kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = kv_aval.shape kv,
assert q_batch_shape == kv_batch_shape jnp.zeros(0, q.dtype),
assert q_head_dim == kv_head_dim bias,
assert nkv == 2 q_seqlen,
out_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) kv_seqlen,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
# backend determines the softmax buffer shape/dtype
backend = FusedAttnHelper(q_dtype, kv_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
attn_bias_type, attn_mask_type, dropout_probability, attn_heads,
num_gqa_groups, q_max_seqlen, kv_max_seqlen,
q_head_dim).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen: def fused_attn_bwd_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
softmax_shape = (*q_batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen) softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray,
softmax_dtype = q_dtype doutput: jnp.ndarray, q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray,
elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
softmax_shape = (*q_batch_shape, attn_heads, q_max_seqlen, 1) scaling_factor: float, dropout_probability: float, is_training: bool):
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) """
else: Wrapper for TE fused attention bwd with kvpacked inputs
raise ValueError(f'Unsupported {backend=}') Return the gradients of fused attention with packed kv input
softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype) """
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
# JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with assert bias is None
# 32-bit unsigned int to get the buffer size we need in the C++ kernel bias = jnp.zeros(0, dtype=q.dtype)
checker = _FusedAttnRNGStateChecker() dummy_input = jnp.zeros(0, q.dtype)
seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype) dq, dkv, _, dbias = FusedAttnBwdPrimitive.outer_primitive.bind(
assert seed_dtype == checker.rng_state_dtype q,
rng_state_shape = (seed_aval.shape[0], checker.rng_state_size) kv,
rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype) dummy_input,
bias,
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: softmax_aux,
bias_batch = bias_heads = 0 rng_state,
else: output,
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape doutput,
bias_batch = reduce(operator.mul, bias_batch_shape) q_seqlen,
kv_seqlen,
# do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to attn_bias_type=attn_bias_type,
# prepare for the active fused-attn backend attn_mask_type=attn_mask_type,
input_batch = reduce(operator.mul, q_batch_shape) qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
wkspace_info = transformer_engine_jax.get_cross_fused_attn_fwd_workspace_sizes( scaling_factor=scaling_factor,
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, dropout_probability=dropout_probability,
attn_heads, num_gqa_groups, bias_heads, q_head_dim, is_training=is_training)
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, return dq, dkv, dbias
jax_dtype_to_te_dtype(q_aval.dtype), is_training)
wkspace_aval = q_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
Cross fused attention fwd outer primitive abstract
"""
out_aval, softmax_aux_aval, rng_state_aval, _ = \
CrossFusedAttnFwdPrimitive.abstract(*args, **kwargs)
return out_aval, softmax_aux_aval, rng_state_aval
@staticmethod
def lowering(ctx, q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training):
"""
Cross fused attention fwd lowering rules
"""
operands = [q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
q_aval, kv_aval, bias_aval, *_ = ctx.avals_in
*input_batch_shape, q_max_seqlen, attn_heads, head_dim = q_aval.shape
*_, kv_max_seqlen, _, num_gqa_groups, _ = kv_aval.shape
input_batch = reduce(operator.mul, input_batch_shape)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, head_dim,
wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
out = custom_caller(CrossFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
return out
@staticmethod
def impl(q, kv, bias, q_seqlen, kv_seqlen, seed, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
assert CrossFusedAttnFwdPrimitive.inner_primitive is not None
q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
output, softmax_aux, rng_state, _ = CrossFusedAttnFwdPrimitive.inner_primitive.bind(
q,
kv,
bias,
q_cu_seqlen,
kv_cu_seqlen,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output, softmax_aux, rng_state
@staticmethod
def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
_check_valid_batch_dims(batch_dims)
assert CrossFusedAttnFwdPrimitive.outer_primitive is not None
q_bdim, *_, seed_bdim = batch_dims
out_bdims = q_bdim, q_bdim, seed_bdim
return CrossFusedAttnFwdPrimitive.outer_primitive.bind(
*batched_args,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training), out_bdims
@staticmethod
def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training, mesh, arg_infos,
result_infos):
del attn_bias_type, attn_mask_type, scaling_factor
del dropout_probability, is_training, result_infos
q_spec = get_padded_spec(arg_infos[0]) # (...batch, q_seqlen, head, hidden)
kv_spec = get_padded_spec(arg_infos[1]) # (...batch, kv_seqlen, 2, head, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], kv_spec[-4]))
rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
return (out_sharding, softmax_aux_sharding, rng_state_sharding)
@staticmethod
def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
mesh, arg_infos, result_infos):
del result_infos
q_spec = get_padded_spec(arg_infos[0]) # (...batch, q_seqlen, head, hidden)
kv_spec = get_padded_spec(arg_infos[1]) # (...batch, kv_seqlen, 2, head, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], kv_spec[-4]))
rng_state_sharding = seed_sharding = NamedSharding(mesh,
PartitionSpec(get_all_mesh_axes(), None))
arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding])
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
impl = partial(CrossFusedAttnFwdPrimitive.impl,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return mesh, impl, out_shardings, arg_shardings
register_primitive(CrossFusedAttnFwdPrimitive)
def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, q_seqlen: jnp.ndarray,
kv_seqlen: jnp.ndarray, seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool):
"""
Wrapper for TE cross fused attention fwd
Return BMM1 -> (PreScaleBias) -> Scale -> (PostScaleBias) -> Softmax -> (Dropout) -> BMM2
"""
checker = _FusedAttnRNGStateChecker()
seed = checker.check_seed(seed, dropout_probability, is_training)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=q.dtype)
return CrossFusedAttnFwdPrimitive.outer_primitive.bind(q,
kv,
bias,
q_seqlen,
kv_seqlen,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
class CrossFusedAttnBwdPrimitive(BasePrimitive):
"""
Cross Fused Attention Backward Primitive
"""
name = "te_cross_fused_attn_backward"
multiple_results = True
impl_static_args = (9, 10, 11, 12, 13)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(q_aval, kv_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval,
doutput_aval, q_cu_seqlen_aval, kv_cu_seqlen_aval, *, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training):
"""
Cross fused attention bwd abstract
"""
del softmax_aux_aval, rng_state_aval, output_aval
q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype)
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype)
assert q_dtype == kv_dtype == bias_dtype == doutput_dtype
assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = kv_aval.shape
assert q_batch_shape == kv_batch_shape
assert q_head_dim == kv_head_dim
assert nkv == 2
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
input_batch = reduce(operator.mul, q_batch_shape)
wkspace_shape, wkspace_dtype = \
transformer_engine_jax.get_cross_fused_attn_bwd_workspace_sizes(
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, q_head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training
)
dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
dkv_aval = kv_aval.update(shape=kv_aval.shape, dtype=kv_dtype)
dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
wkspace_aval = q_aval.update(shape=wkspace_shape,
dtype=te_dtype_to_jax_dtype(wkspace_dtype))
return dq_aval, dkv_aval, dbias_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
Cross fused attention bwd outer primitive abstract
"""
dq_aval, dkv_aval, dbias_aval, _ = \
CrossFusedAttnBwdPrimitive.abstract(*args, **kwargs)
return dq_aval, dkv_aval, dbias_aval
@staticmethod
def lowering(ctx, q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen,
kv_cu_seqlen, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
"""
Cross fused attention bwd lowering rules
"""
operands = [q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
q_aval, kv_aval, bias_aval, *_ = ctx.avals_in
*input_batch_shape, q_max_seqlen, attn_heads, head_dim = q_aval.shape
*_, kv_max_seqlen, _, num_gqa_groups, _ = kv_aval.shape
input_batch = reduce(operator.mul, input_batch_shape)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, head_dim,
wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
out = custom_caller(CrossFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
return out
@staticmethod
def impl(q, kv, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen,
attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training):
assert CrossFusedAttnBwdPrimitive.inner_primitive is not None
q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
dq, dkv, dbias, _ = CrossFusedAttnBwdPrimitive.inner_primitive.bind(
q,
kv,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return dq, dkv, dbias
@staticmethod
def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
_check_valid_batch_dims(batch_dims)
assert CrossFusedAttnBwdPrimitive.outer_primitive is not None
q_bdim, kv_bdim, *_ = batch_dims
out_bdims = q_bdim, kv_bdim, q_bdim
return CrossFusedAttnBwdPrimitive.outer_primitive.bind(
*batched_args,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training), out_bdims
@staticmethod
def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training, mesh, arg_infos,
result_infos):
del attn_bias_type, attn_mask_type, scaling_factor
del dropout_probability, is_training, result_infos
q_spec = get_padded_spec(arg_infos[0])
kv_spec = get_padded_spec(arg_infos[1])
bias_spec = get_padded_spec(arg_infos[2])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dkv_sharding = NamedSharding(mesh, PartitionSpec(*kv_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
return (dq_sharding, dkv_sharding, dbias_sharding)
@staticmethod
def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
mesh, arg_infos, result_infos):
del result_infos
q_spec = get_padded_spec(arg_infos[0])
kv_spec = get_padded_spec(arg_infos[1])
bias_spec = get_padded_spec(arg_infos[2])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dkv_sharding = NamedSharding(mesh, PartitionSpec(*kv_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (dq_sharding, dkv_sharding, dbias_sharding)
def sharded_impl(q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen,
kv_cu_seqlen):
local_dq, local_dkv, local_dbias = CrossFusedAttnBwdPrimitive.impl(
q,
kv,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
global_dbias = local_dbias
if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
return local_dq, local_dkv, global_dbias
return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(CrossFusedAttnBwdPrimitive)
def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray,
doutput: jnp.ndarray, q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
scaling_factor: float, dropout_probability: float, is_training: bool):
"""
Wrapper for TE cross fused attention bwd
Return the gradients of cross fused attention with packed kv input
"""
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=q.dtype)
return CrossFusedAttnBwdPrimitive.outer_primitive.bind(q,
kv,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_seqlen,
kv_seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
class FusedAttnFwdPrimitive(BasePrimitive):
"""
Fused Attention Forward Primitive
Query, key, value are separated tensors
"""
name = "te_fused_attn_forward"
multiple_results = True
impl_static_args = (7, 8, 9, 10, 11)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(q_aval, k_aval, v_aval, bias_aval, q_seqlen_or_cu_seqlen_aval,
kv_seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
"""
Fused attention fwd abstract
"""
q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
assert q_dtype == k_dtype == v_dtype == bias_dtype
assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape
assert q_batch_shape == kv_batch_shape
assert q_head_dim == kv_head_dim
assert k_aval.shape == v_aval.shape
out_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
# backend determines the softmax buffer shape/dtype
backend = FusedAttnHelper(q_dtype, k_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD, attn_bias_type,
attn_mask_type, dropout_probability, attn_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, q_head_dim).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_shape = (*q_batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen)
softmax_dtype = q_dtype
elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
softmax_shape = (*q_batch_shape, attn_heads, q_max_seqlen, 1)
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
else:
raise ValueError(f'Unsupported {backend=}')
softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype)
# JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with
# 32-bit unsigned int to get the buffer size we need in the C++ kernel
checker = _FusedAttnRNGStateChecker()
seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
assert seed_dtype == checker.rng_state_dtype
rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
# do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
# prepare for the active fused-attn backend
input_batch = reduce(operator.mul, q_batch_shape)
wkspace_info = transformer_engine_jax.get_fused_attn_fwd_workspace_sizes(
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, q_head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training)
wkspace_aval = q_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
Fused attention fwd outer primitive abstract
"""
out_aval, softmax_aux_aval, rng_state_aval, _ = \
FusedAttnFwdPrimitive.abstract(*args, **kwargs)
return out_aval, softmax_aux_aval, rng_state_aval
@staticmethod
def lowering(ctx, q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training):
"""
Fused attention fwd lowering rules
"""
operands = [q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
*batch_shape, q_max_seqlen, attn_heads, head_dim = q_aval.shape
*_, kv_max_seqlen, num_gqa_groups, _ = k_aval.shape
assert k_aval.shape == v_aval.shape
input_batch = reduce(operator.mul, batch_shape)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, head_dim,
wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
return out
@staticmethod
def impl(q, k, v, bias, q_seqlen, kv_seqlen, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
assert FusedAttnFwdPrimitive.inner_primitive is not None
q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind(
q,
k,
v,
bias,
q_cu_seqlen,
kv_cu_seqlen,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output, softmax_aux, rng_state
@staticmethod
def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
_check_valid_batch_dims(batch_dims)
assert FusedAttnFwdPrimitive.outer_primitive is not None
q_bdim, *_, seed_bdim = batch_dims
out_bdims = q_bdim, q_bdim, seed_bdim
return FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training), out_bdims
@staticmethod
def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training, mesh, arg_infos,
result_infos):
del attn_bias_type, attn_mask_type, scaling_factor
del dropout_probability, is_training, result_infos
q_spec = get_padded_spec(arg_infos[0]) # (...batch, q_seqlen, head, hidden)
k_spec = get_padded_spec(arg_infos[1]) # (...batch, kv_seqlen, head, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3]))
rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
return (out_sharding, softmax_aux_sharding, rng_state_sharding)
@staticmethod
def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
mesh, arg_infos, result_infos):
del result_infos
q_spec = get_padded_spec(arg_infos[0]) # (...batch, q_seqlen, head, hidden)
k_spec = get_padded_spec(arg_infos[1]) # (...batch, kv_seqlen, head, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3]))
rng_state_sharding = seed_sharding = NamedSharding(mesh,
PartitionSpec(get_all_mesh_axes(), None))
arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding])
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
impl = partial(FusedAttnFwdPrimitive.impl,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return mesh, impl, out_shardings, arg_shardings
register_primitive(FusedAttnFwdPrimitive)
def fused_attn_fwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, def fused_attn_fwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray,
...@@ -3006,7 +2504,7 @@ def fused_attn_fwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.nda ...@@ -3006,7 +2504,7 @@ def fused_attn_fwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.nda
scaling_factor: float, dropout_probability: float, is_training: bool): scaling_factor: float, dropout_probability: float, is_training: bool):
""" """
Wrapper for TE fused attention fwd, where query, key, value are separated tensors Wrapper for TE fused attention fwd, where query, key, value are separated tensors
Return BMM1 -> (PreScaleBias) -> Scale -> (PostScaleBias) -> Softmax -> (Dropout) -> BMM2 Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
""" """
checker = _FusedAttnRNGStateChecker() checker = _FusedAttnRNGStateChecker()
seed = checker.check_seed(seed, dropout_probability, is_training) seed = checker.check_seed(seed, dropout_probability, is_training)
...@@ -3015,228 +2513,20 @@ def fused_attn_fwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.nda ...@@ -3015,228 +2513,20 @@ def fused_attn_fwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.nda
assert bias is None assert bias is None
bias = jnp.zeros(0, dtype=q.dtype) bias = jnp.zeros(0, dtype=q.dtype)
return FusedAttnFwdPrimitive.outer_primitive.bind(q, return FusedAttnFwdPrimitive.outer_primitive.bind(
k, q,
v, k,
bias, v,
q_seqlen, bias,
kv_seqlen, q_seqlen,
seed, kv_seqlen,
attn_bias_type=attn_bias_type, seed,
attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type,
scaling_factor=scaling_factor, attn_mask_type=attn_mask_type,
dropout_probability=dropout_probability, qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD,
is_training=is_training) scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
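The separated-tensor path (NVTE_BSHD_BSHD_BSHD) allows k and v to carry num_gqa_groups heads, which may be fewer than the query's attn_heads. A minimal pure-JAX sketch of how grouped-query attention broadcasts the KV heads is shown below; the function name and shapes are hypothetical and this is not the fused kernel.

import jax
import jax.numpy as jnp

def reference_attn_gqa(q, k, v, scaling_factor):
    """Pure-JAX sketch of separate-q/k/v attention with grouped-query KV heads.

    q: (batch, q_seqlen, attn_heads, head_dim)
    k, v: (batch, kv_seqlen, num_gqa_groups, head_dim), attn_heads % num_gqa_groups == 0
    """
    attn_heads, num_gqa_groups = q.shape[-2], k.shape[-2]
    assert attn_heads % num_gqa_groups == 0
    # Broadcast each KV head to all query heads in its group.
    rep = attn_heads // num_gqa_groups
    k = jnp.repeat(k, rep, axis=-2)
    v = jnp.repeat(v, rep, axis=-2)
    logits = jnp.einsum('bqhd,bkhd->bhqk', q, k) * scaling_factor
    probs = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum('bhqk,bkhd->bqhd', probs, v)

# Hypothetical shapes: 16 query heads sharing 4 KV groups.
b, s_q, s_kv, h, groups, d = 2, 128, 256, 16, 4, 64
q = jnp.zeros((b, s_q, h, d), jnp.bfloat16)
k = v = jnp.zeros((b, s_kv, groups, d), jnp.bfloat16)
out = reference_attn_gqa(q, k, v, scaling_factor=1.0 / jnp.sqrt(d))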
class FusedAttnBwdPrimitive(BasePrimitive):
"""
Fused Attention Backward Primitive
"""
name = "te_fused_attn_backward"
multiple_results = True
impl_static_args = (10, 11, 12, 13, 14)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(q_aval, k_aval, v_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval,
doutput_aval, q_cu_seqlen_aval, kv_cu_seqlen_aval, *, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training):
"""
Fused attention bwd abstract
"""
del softmax_aux_aval, rng_state_aval, output_aval
q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype)
assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype
assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape
assert q_batch_shape == kv_batch_shape
assert q_head_dim == kv_head_dim
assert k_aval.shape == v_aval.shape
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
input_batch = reduce(operator.mul, q_batch_shape)
wkspace_shape, wkspace_dtype = \
transformer_engine_jax.get_fused_attn_bwd_workspace_sizes(
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, q_head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training
)
dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
dk_aval = k_aval.update(shape=k_aval.shape, dtype=k_dtype)
dv_aval = v_aval.update(shape=v_aval.shape, dtype=v_dtype)
dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
wkspace_aval = q_aval.update(shape=wkspace_shape,
dtype=te_dtype_to_jax_dtype(wkspace_dtype))
return dq_aval, dk_aval, dv_aval, dbias_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
Fused attention bwd outer primitive abstract
"""
dq_aval, dk_aval, dv_aval, dbias_aval, _ = \
FusedAttnBwdPrimitive.abstract(*args, **kwargs)
return dq_aval, dk_aval, dv_aval, dbias_aval
@staticmethod
def lowering(ctx, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen,
kv_cu_seqlen, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
"""
Fused attention bwd lowering rules
"""
operands = [
q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen
]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
*batch_shape, q_max_seqlen, attn_heads, head_dim = q_aval.shape
*_, kv_max_seqlen, num_gqa_groups, _ = k_aval.shape
assert k_aval.shape == v_aval.shape
input_batch = reduce(operator.mul, batch_shape)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, head_dim,
wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
return out
@staticmethod
def impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen,
attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training):
assert FusedAttnBwdPrimitive.inner_primitive is not None
q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
dq, dk, dv, dbias, _ = FusedAttnBwdPrimitive.inner_primitive.bind(
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return dq, dk, dv, dbias
@staticmethod
def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
_check_valid_batch_dims(batch_dims)
assert FusedAttnBwdPrimitive.outer_primitive is not None
q_bdim, k_bdim, v_bdim, *_ = batch_dims
out_bdims = q_bdim, k_bdim, v_bdim, q_bdim
return FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training), out_bdims
@staticmethod
def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training, mesh, arg_infos,
result_infos):
del attn_bias_type, attn_mask_type, scaling_factor
del dropout_probability, is_training, result_infos
q_spec = get_padded_spec(arg_infos[0])
k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2])
bias_spec = get_padded_spec(arg_infos[3])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
@staticmethod
def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
mesh, arg_infos, result_infos):
del result_infos
q_spec = get_padded_spec(arg_infos[0])
k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2])
bias_spec = get_padded_spec(arg_infos[3])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
def sharded_impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen,
kv_cu_seqlen):
local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl(
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
global_dbias = local_dbias
if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
return local_dq, local_dk, local_dv, global_dbias
return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(FusedAttnBwdPrimitive)
def fused_attn_bwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, def fused_attn_bwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray,
...@@ -3251,22 +2541,23 @@ def fused_attn_bwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.nda ...@@ -3251,22 +2541,23 @@ def fused_attn_bwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.nda
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None assert bias is None
bias = jnp.zeros(0, dtype=q.dtype) bias = jnp.zeros(0, dtype=q.dtype)
return FusedAttnBwdPrimitive.outer_primitive.bind(
return FusedAttnBwdPrimitive.outer_primitive.bind(q, q,
k, k,
v, v,
bias, bias,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
doutput, doutput,
q_seqlen, q_seqlen,
kv_seqlen, kv_seqlen,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor, qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD,
dropout_probability=dropout_probability, scaling_factor=scaling_factor,
is_training=is_training) dropout_probability=dropout_probability,
is_training=is_training)
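Forward/backward wrapper pairs like the ones above are typically stitched into JAX autodiff with jax.custom_vjp: the forward pass saves residuals such as the softmax statistics and RNG state, and the backward pass consumes them. The sketch below shows that general pattern on a stand-in attention function with hypothetical names and residuals; it is not the actual transformer_engine.jax glue, whose wrappers take many more operands (seqlens, seed, layout, bias, etc.).

import jax
import jax.numpy as jnp

def _toy_fwd_kernel(q, k, v):
    # Unscaled attention forward; probs plays the role of the saved softmax_aux.
    logits = jnp.einsum('bqhd,bkhd->bhqk', q, k)
    probs = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum('bhqk,bkhd->bqhd', probs, v)
    return out, probs

@jax.custom_vjp
def toy_fused_attn(q, k, v):
    out, _ = _toy_fwd_kernel(q, k, v)
    return out

def _toy_fused_attn_fwd(q, k, v):
    out, probs = _toy_fwd_kernel(q, k, v)
    return out, (q, k, v, probs)      # residuals for the backward pass

def _toy_fused_attn_bwd(res, dout):
    q, k, v, probs = res
    dv = jnp.einsum('bhqk,bqhd->bkhd', probs, dout)
    dprobs = jnp.einsum('bqhd,bkhd->bhqk', dout, v)
    # Softmax backward: dlogits = probs * (dprobs - sum_k(dprobs * probs))
    dlogits = probs * (dprobs - jnp.sum(dprobs * probs, axis=-1, keepdims=True))
    dq = jnp.einsum('bhqk,bkhd->bqhd', dlogits, k)
    dk = jnp.einsum('bhqk,bqhd->bkhd', dlogits, q)
    return dq, dk, dv

toy_fused_attn.defvjp(_toy_fused_attn_fwd, _toy_fused_attn_bwd)

# Usage: gradients flow through the custom rule instead of tracing the kernel.
q = k = v = jnp.ones((1, 8, 2, 4), jnp.float32)
dq, dk, dv = jax.grad(lambda q, k, v: toy_fused_attn(q, k, v).sum(),
                      argnums=(0, 1, 2))(q, k, v)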
class GeluPrimitive(BasePrimitive): class GeluPrimitive(BasePrimitive):
......
...@@ -49,10 +49,6 @@ pybind11::dict Registrations() { ...@@ -49,10 +49,6 @@ pybind11::dict Registrations() {
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForward); EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForward);
dict["te_scaled_upper_triang_masked_softmax_backward"] = dict["te_scaled_upper_triang_masked_softmax_backward"] =
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward); EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward);
dict["te_self_fused_attn_forward"] = EncapsulateFunction(SelfFusedAttnForward);
dict["te_self_fused_attn_backward"] = EncapsulateFunction(SelfFusedAttnBackward);
dict["te_cross_fused_attn_forward"] = EncapsulateFunction(CrossFusedAttnForward);
dict["te_cross_fused_attn_backward"] = EncapsulateFunction(CrossFusedAttnBackward);
dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward); dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward);
dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward); dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward);
return dict; return dict;
...@@ -72,10 +68,6 @@ PYBIND11_MODULE(transformer_engine_jax, m) { ...@@ -72,10 +68,6 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_dgelu_dbias_ct_workspace_sizes", &GetDGeluDBiasCastTransposeWorkspaceSizes); m.def("get_dgelu_dbias_ct_workspace_sizes", &GetDGeluDBiasCastTransposeWorkspaceSizes);
m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes); m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes);
m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes); m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes);
m.def("get_self_fused_attn_fwd_workspace_sizes", &GetSelfFusedAttnForwardWorkspaceSizes);
m.def("get_self_fused_attn_bwd_workspace_sizes", &GetSelfFusedAttnBackwardWorkspaceSizes);
m.def("get_cross_fused_attn_fwd_workspace_sizes", &GetCrossFusedAttnForwardWorkspaceSizes);
m.def("get_cross_fused_attn_bwd_workspace_sizes", &GetCrossFusedAttnBackwardWorkspaceSizes);
m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes);
m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
......
...@@ -11,14 +11,12 @@ ...@@ -11,14 +11,12 @@
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include <cudnn.h> #include <cudnn.h>
#include <functional>
#include <numeric>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <vector> #include <vector>
#include "common/common.h" #include "common/common.h"
#include "common/util/cuda_runtime.h" #include "common/util/logging.h"
#include "transformer_engine/activation.h" #include "transformer_engine/activation.h"
#include "transformer_engine/cast.h" #include "transformer_engine/cast.h"
#include "transformer_engine/fused_attn.h" #include "transformer_engine/fused_attn.h"
...@@ -96,13 +94,13 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t paddin ...@@ -96,13 +94,13 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t paddin
pybind11::bytes PackCustomCallFusedAttnDescriptor( pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
size_t wkspace_size, float scaling_factor, float dropout_probability, size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
DType dtype, DType wkspace_dtype, bool is_training) { bool is_training) {
return PackOpaque(CustomCallFusedAttnDescriptor{ return PackOpaque(CustomCallFusedAttnDescriptor{
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups,
bias_heads, head_dim, wkspace_size, scaling_factor, dropout_probability, bias_heads, head_dim, wkspace_size, scaling_factor, dropout_probability, bias_type,
bias_type, mask_type, dtype, wkspace_dtype, is_training}); mask_type, qkv_layout, dtype, wkspace_dtype, is_training});
} }
void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream, void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream,
...@@ -942,12 +940,12 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, ...@@ -942,12 +940,12 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers,
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability, NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_num_heads, size_t kv_num_heads, size_t q_attn_heads, size_t kv_attn_heads,
size_t q_max_seqlen, size_t kv_max_seqlen, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim) { size_t head_dim) {
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_num_heads, kv_num_heads, q_max_seqlen, kv_max_seqlen, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen,
head_dim); head_dim);
return backend; return backend;
} }
...@@ -1029,244 +1027,31 @@ void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, ...@@ -1029,244 +1027,31 @@ void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack,
} }
} }
pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes( pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t max_seqlen,
size_t attn_heads, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto qkv_shape = std::vector<size_t>{input_batch * max_seqlen, 3, attn_heads, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, max_seqlen, max_seqlen};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto o_tensor = TensorWrapper(
nullptr, std::vector<size_t>{input_batch, max_seqlen, attn_heads, head_dim}, dtype);
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, attn_heads, max_seqlen, max_seqlen, head_dim);
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
TensorWrapper query_workspace_tensor;
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
rng_state_tensor.data(), max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}
void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input buffers from XLA
void *qkv = buffers[0];
void *bias = buffers[1];
void *cu_seqlens = buffers[2];
void *seed = buffers[3];
// output buffers from XLA
void *output = buffers[4];
void *softmax_aux = buffers[5];
void *rng_state = buffers[6];
void *workspace = buffers[7];
// tensor sizes
auto input_batch = descriptor.input_batch;
auto bias_batch = descriptor.bias_batch;
auto max_seqlen = descriptor.q_max_seqlen;
auto attn_heads = descriptor.attn_heads;
auto bias_heads = descriptor.bias_heads;
auto head_dim = descriptor.head_dim;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
auto dtype = descriptor.dtype;
auto qkv_shape = std::vector<size_t>{input_batch * max_seqlen, 3, attn_heads, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, max_seqlen, max_seqlen};
// input tensors
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
auto cu_seqlens_tensor =
TensorWrapper(cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
// output tensors
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in FP16/BF16
auto o_tensor = TensorWrapper(
output, std::vector<size_t>{input_batch * max_seqlen, attn_heads, head_dim}, dtype);
// prep RNG state
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, attn_heads, max_seqlen, max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, max_seqlen, max_seqlen, backend, stream);
// auxiliary tensors (to be propagated to the backward pass later)
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend,
softmax_aux);
// cuDNN workspace
auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
auto wkspace_dtype = descriptor.wkspace_dtype;
auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
rng_state_tensor.data(), max_seqlen, descriptor.is_training,
descriptor.scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors);
}
pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t max_seqlen,
size_t attn_heads, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto qkv_shape = std::vector<size_t>{input_batch * max_seqlen, 3, attn_heads, head_dim};
auto output_shape = std::vector<size_t>{input_batch * max_seqlen, attn_heads, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, max_seqlen, max_seqlen};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
// F16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
TensorWrapper query_workspace_tensor;
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
cu_seqlens_tensor.data(), max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}
void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input buffers from XLA
void *qkv = buffers[0];
void *bias = buffers[1];
void *softmax_aux = buffers[2];
void *rng_state = buffers[3];
void *output = buffers[4];
void *doutput = buffers[5];
void *cu_seqlens = buffers[6];
// output buffers from XLA
void *dqkv = buffers[7];
void *dbias = buffers[8];
void *workspace = buffers[9];
// tensor sizes
auto input_batch = descriptor.input_batch;
auto bias_batch = descriptor.bias_batch;
auto max_seqlen = descriptor.q_max_seqlen;
auto attn_heads = descriptor.attn_heads;
auto bias_heads = descriptor.bias_heads;
auto head_dim = descriptor.head_dim;
auto scaling_factor = descriptor.scaling_factor;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
auto dtype = descriptor.dtype;
auto qkv_shape = std::vector<size_t>{input_batch * max_seqlen, 3, attn_heads, head_dim};
auto output_shape = std::vector<size_t>{input_batch * max_seqlen, attn_heads, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, max_seqlen, max_seqlen};
// input tensors
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
auto output_tensor = TensorWrapper(output, output_shape, dtype);
auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
// output tensors
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in FP16/BF16
auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
auto cu_seqlens_tensor =
TensorWrapper(cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
// auxiliary tensors (propagated from the forward pass)
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, attn_heads, max_seqlen, max_seqlen, head_dim);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
rng_state, bias);
// cuDNN workspace
auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
auto wkspace_dtype = descriptor.wkspace_dtype;
auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
cu_seqlens_tensor.data(), max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_input_tensors);
}
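The forward/backward pair above communicates through the auxiliary tensor pack: the forward kernel writes softmax statistics and the RNG state, and the backward kernel reads them back. At the JAX level this is the usual residual-passing pattern of jax.custom_vjp; the toy example below uses plain softmax (not the fused kernel) purely to illustrate how such residuals flow:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def toy_softmax(x):
    return jax.nn.softmax(x, axis=-1)

def toy_softmax_fwd(x):
    y = jax.nn.softmax(x, axis=-1)
    return y, (y,)    # residuals play the role of softmax_aux / rng_state

def toy_softmax_bwd(residuals, dy):
    (y,) = residuals  # consumed by the backward rule, like aux_input_tensors
    dx = y * (dy - jnp.sum(dy * y, axis=-1, keepdims=True))
    return (dx,)

toy_softmax.defvjp(toy_softmax_fwd, toy_softmax_bwd)

grads = jax.grad(lambda x: toy_softmax(x)[..., 0].sum())(jnp.ones((2, 4)))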
pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) { NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD; // For qkv_packed
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
// For kv_packed
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype); auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
// For separate q, k, v
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto v_shape = k_shape;
auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype); auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
// FP16/BF16 doesn't use this tensor // F16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto o_tensor = TensorWrapper(nullptr, q_shape, dtype); auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);
...@@ -1281,292 +1066,133 @@ pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
nvte_tensor_pack_create(&aux_output_tensors); nvte_tensor_pack_create(&aux_output_tensors);
TensorWrapper query_workspace_tensor; TensorWrapper query_workspace_tensor;
nvte_fused_attn_fwd_kvpacked(q_tensor.data(), kv_tensor.data(), bias_tensor.data(), if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
s_tensor.data(), o_tensor.data(), &aux_output_tensors, assert(q_max_seqlen == kv_max_seqlen);
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), nvte_fused_attn_fwd_qkvpacked(
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
is_training, scaling_factor, dropout_probability, qkv_layout, &aux_output_tensors, q_cu_seqlens_tensor.data(), dummy_rng_state_tensor.data(),
bias_type, mask_type, query_workspace_tensor.data(), nullptr); q_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); nvte_fused_attn_fwd_kvpacked(q_tensor.data(), kv_tensor.data(), bias_tensor.data(),
} s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
size_t opaque_len) { is_training, scaling_factor, dropout_probability, qkv_layout,
const CustomCallFusedAttnDescriptor &descriptor = bias_type, mask_type, query_workspace_tensor.data(), nullptr);
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
// input buffers from XLA s_tensor.data(), o_tensor.data(), &aux_output_tensors,
void *q = buffers[0]; q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
void *kv = buffers[1]; dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
void *bias = buffers[2]; scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
void *q_cu_seqlens = buffers[3]; query_workspace_tensor.data(), nullptr);
void *kv_cu_seqlens = buffers[4]; } else {
void *seed = buffers[5]; NVTE_ERROR("Unsupported QKVLayout.");
}
// output buffers from XLA
void *output = buffers[6];
void *softmax_aux = buffers[7];
void *rng_state = buffers[8];
void *workspace = buffers[9];
// tensor sizes
auto input_batch = descriptor.input_batch;
auto bias_batch = descriptor.bias_batch;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto attn_heads = descriptor.attn_heads;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto bias_heads = descriptor.bias_heads;
auto head_dim = descriptor.head_dim;
auto scaling_factor = descriptor.scaling_factor;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
// input tensors
auto dtype = descriptor.dtype;
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
// output tensors
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in FP16/BF16
auto o_tensor = TensorWrapper(output, q_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
// prep RNG state
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
// auxiliary tensors (to be propagated to the backward pass later)
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend,
softmax_aux);
// cuDNN workspace
auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size},
descriptor.wkspace_dtype);
nvte_fused_attn_fwd_kvpacked(q_tensor.data(), kv_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
descriptor.is_training, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors); auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype());
} }
pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes( pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) { bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD; auto output_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
auto output_tensor = TensorWrapper(nullptr, output_shape, dtype); auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
// FP16/BF16 doesn't use this tensor auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); auto bias_shape = std::vector<size_t>{1, attn_heads, q_max_seqlen, kv_max_seqlen};
auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype); auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto q_cu_seqlens_tensor = // F16 doesn't use s_tensor
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32); auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto kv_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
TensorWrapper query_workspace_tensor;
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for FP16/BF16
s_tensor.data(), // not used for FP16/BF16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}
void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input buffers from XLA
void *q = buffers[0];
void *kv = buffers[1];
void *bias = buffers[2];
void *softmax_aux = buffers[3];
void *rng_state = buffers[4];
void *output = buffers[5];
void *doutput = buffers[6];
void *q_cu_seqlens = buffers[7];
void *kv_cu_seqlens = buffers[8];
// output buffers from XLA
void *dq = buffers[9];
void *dkv = buffers[10];
void *dbias = buffers[11];
void *workspace = buffers[12];
// tensor sizes
auto input_batch = descriptor.input_batch;
auto bias_batch = descriptor.bias_batch;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto attn_heads = descriptor.attn_heads;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto bias_heads = descriptor.bias_heads;
auto head_dim = descriptor.head_dim;
auto scaling_factor = descriptor.scaling_factor;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
// input tensors
auto dtype = descriptor.dtype;
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
auto output_tensor = TensorWrapper(output, output_shape, dtype);
auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
// output tensors
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in FP16/BF16
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype);
auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
auto q_cu_seqlens_tensor = auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32); TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor = auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32); TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
// auxiliary tensors (propagated from the forward pass)
NVTETensorPack aux_input_tensors; NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors); nvte_tensor_pack_create(&aux_input_tensors);
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
rng_state, bias);
// cuDNN workspace
auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
auto wkspace_dtype = descriptor.wkspace_dtype;
auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for FP16/BF16
s_tensor.data(), // not used for FP16/BF16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_input_tensors);
}
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
// F16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
TensorWrapper query_workspace_tensor; TensorWrapper query_workspace_tensor;
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); assert(q_max_seqlen == kv_max_seqlen);
auto qkv_shape = std::vector<size_t>{batch_size * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
nvte_fused_attn_bwd_qkvpacked(
qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_shape =
std::vector<size_t>{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto k_shape = std::vector<size_t>{batch_size * kv_max_seqlen, num_gqa_groups, head_dim};
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto v_shape = k_shape;
auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype);
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(),
dv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported QKVLayout.");
}
auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype());
} }
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor = const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len); *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input buffers from XLA /* Input buffers from XLA */
void *q = buffers[0]; /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */
void *k = buffers[1];
void *v = buffers[2];
void *bias = buffers[3]; void *bias = buffers[3];
void *q_cu_seqlens = buffers[4]; void *q_cu_seqlens = buffers[4];
void *kv_cu_seqlens = buffers[5]; void *kv_cu_seqlens = buffers[5];
void *seed = buffers[6]; void *seed = buffers[6];
// output buffers from XLA /* Output buffer from XLA */
void *output = buffers[7]; void *output = buffers[7];
void *softmax_aux = buffers[8]; void *softmax_aux = buffers[8];
void *rng_state = buffers[9]; void *rng_state = buffers[9];
void *workspace = buffers[10]; void *workspace = buffers[10];
// tensor sizes /* Descriptor */
auto input_batch = descriptor.input_batch; auto input_batch = descriptor.input_batch;
auto bias_batch = descriptor.bias_batch; auto bias_batch = descriptor.bias_batch;
auto q_max_seqlen = descriptor.q_max_seqlen; auto q_max_seqlen = descriptor.q_max_seqlen;
...@@ -1579,29 +1205,26 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto dropout_probability = descriptor.dropout_probability; auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type; auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type; auto mask_type = descriptor.mask_type;
auto qkv_layout = descriptor.qkv_layout;
auto dtype = descriptor.dtype;
/* Input tensors */
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape; auto v_shape = k_shape;
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
// input tensors
auto dtype = descriptor.dtype;
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype);
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype); auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
// output tensors /* Output tensors */
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16 auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16
auto o_tensor = TensorWrapper(output, q_shape, dtype); auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto o_tensor = TensorWrapper(output, o_shape, dtype);
auto q_cu_seqlens_tensor = auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32); TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor = auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32); TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
// prep RNG state /* Prepare RNG state */
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64); auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
...@@ -1609,22 +1232,59 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
head_dim); head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
// auxiliary tensors (to be propagated to the backward pass later) /* Auxiliary tensors (to be propagated to the backward pass later) */
NVTETensorPack aux_output_tensors; NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors); nvte_tensor_pack_create(&aux_output_tensors);
PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend, PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend,
softmax_aux); softmax_aux);
// cuDNN workspace /* cuDNN workspace */
auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size}, auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size},
descriptor.wkspace_dtype); descriptor.wkspace_dtype);
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), /* Call the underlying NVTE API */
s_tensor.data(), o_tensor.data(), &aux_output_tensors, if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), auto qkv = buffers[0];
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
descriptor.is_training, scaling_factor, dropout_probability, qkv_layout, auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
bias_type, mask_type, workspace_tensor.data(), stream); nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), rng_state_tensor.data(), q_max_seqlen,
descriptor.is_training, descriptor.scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv = buffers[1];
auto kv_shape =
std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k = buffers[1];
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v = buffers[2];
auto v_shape = k_shape;
auto v_tensor = TensorWrapper(v, v_shape, dtype);
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
descriptor.is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
nvte_tensor_pack_destroy(&aux_output_tensors); nvte_tensor_pack_destroy(&aux_output_tensors);
} }
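The dispatch above is the heart of the refactor: buffers[0-2] are reinterpreted per qkv_layout instead of having separate Self/Cross entry points. The corresponding host-side array layouts, with made-up batch/sequence/head sizes for illustration, look like this:

import jax.numpy as jnp

b, s_q, s_kv, h, gqa, d = 2, 128, 128, 16, 16, 64

# NVTE_BS3HD: Q, K and V packed into one tensor (q_max_seqlen == kv_max_seqlen).
qkv_packed = jnp.zeros((b, s_q, 3, h, d), dtype=jnp.bfloat16)

# NVTE_BSHD_BS2HD: separate Q plus a packed K/V tensor.
q = jnp.zeros((b, s_q, h, d), dtype=jnp.bfloat16)
kv_packed = jnp.zeros((b, s_kv, 2, gqa, d), dtype=jnp.bfloat16)

# NVTE_BSHD_BSHD_BSHD: fully separate Q, K, V; K/V may carry fewer (GQA) heads.
k = jnp.zeros((b, s_kv, gqa, d), dtype=jnp.bfloat16)
v = jnp.zeros((b, s_kv, gqa, d), dtype=jnp.bfloat16)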
...@@ -1632,10 +1292,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) { NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape; auto v_shape = k_shape;
...@@ -1682,10 +1340,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
const CustomCallFusedAttnDescriptor &descriptor = const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len); *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input buffers from XLA /* Input buffers from XLA */
void *q = buffers[0]; /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */
void *k = buffers[1];
void *v = buffers[2];
void *bias = buffers[3]; void *bias = buffers[3];
void *softmax_aux = buffers[4]; void *softmax_aux = buffers[4];
void *rng_state = buffers[5]; void *rng_state = buffers[5];
...@@ -1694,14 +1350,12 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
void *q_cu_seqlens = buffers[8]; void *q_cu_seqlens = buffers[8];
void *kv_cu_seqlens = buffers[9]; void *kv_cu_seqlens = buffers[9];
// output buffers from XLA /* Output buffer from XLA */
void *dq = buffers[10]; /* Buffers[10-12] are dq, dk, dv, which are parsed later for different qkv_layout */
void *dk = buffers[11];
void *dv = buffers[12];
void *dbias = buffers[13]; void *dbias = buffers[13];
void *workspace = buffers[14]; void *workspace = buffers[14];
// tensor sizes /* Descriptor */
auto input_batch = descriptor.input_batch; auto input_batch = descriptor.input_batch;
auto bias_batch = descriptor.bias_batch; auto bias_batch = descriptor.bias_batch;
auto q_max_seqlen = descriptor.q_max_seqlen; auto q_max_seqlen = descriptor.q_max_seqlen;
...@@ -1714,36 +1368,26 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto dropout_probability = descriptor.dropout_probability; auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type; auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type; auto mask_type = descriptor.mask_type;
auto qkv_layout = descriptor.qkv_layout;
auto dtype = descriptor.dtype;
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; /* Input tensors */
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
// input tensors
auto dtype = descriptor.dtype;
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype);
auto output_tensor = TensorWrapper(output, output_shape, dtype); auto output_tensor = TensorWrapper(output, output_shape, dtype);
auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype); auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
// output tensors /* Output tensors */
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16 auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dk_tensor = TensorWrapper(dk, k_shape, dtype);
auto dv_tensor = TensorWrapper(dv, v_shape, dtype);
auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype); auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
auto q_cu_seqlens_tensor = auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32); TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor = auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32); TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
// auxiliary tensors (propagated from the forward pass) /* Auxiliary tensors (propagated from the forward pass) */
NVTETensorPack aux_input_tensors; NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors); nvte_tensor_pack_create(&aux_input_tensors);
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
...@@ -1751,20 +1395,73 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux, PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
rng_state, bias); rng_state, bias);
// cuDNN workspace /* cuDNN workspace */
auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size}; auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
auto wkspace_dtype = descriptor.wkspace_dtype; auto wkspace_dtype = descriptor.wkspace_dtype;
auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype); auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), /* Call the underlying NVTE API */
doutput_tensor.data(), if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
s_tensor.data(), // not used for F16 auto qkv = buffers[0];
s_tensor.data(), // not used for F16 auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
dbias_tensor.data(), q_cu_seqlens_tensor.data(), auto dqkv = buffers[10];
kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
dropout_probability, qkv_layout, bias_type, mask_type, nvte_fused_attn_bwd_qkvpacked(
workspace_tensor.data(), stream); qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv = buffers[1];
auto kv_shape =
std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
auto dq = buffers[10];
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dkv = buffers[11];
auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype);
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k = buffers[1];
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v = buffers[2];
auto v_shape = k_shape;
auto v_tensor = TensorWrapper(v, v_shape, dtype);
auto dq = buffers[10];
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dk = buffers[11];
auto dk_tensor = TensorWrapper(dk, k_shape, dtype);
auto dv = buffers[12];
auto dv_tensor = TensorWrapper(dv, v_shape, dtype);
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(),
dv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
nvte_tensor_pack_destroy(&aux_input_tensors); nvte_tensor_pack_destroy(&aux_input_tensors);
} }
...
...@@ -118,17 +118,18 @@ struct CustomCallFusedAttnDescriptor {
float dropout_probability; float dropout_probability;
NVTE_Bias_Type bias_type; NVTE_Bias_Type bias_type;
NVTE_Mask_Type mask_type; NVTE_Mask_Type mask_type;
NVTE_QKV_Layout qkv_layout;
DType dtype; DType dtype;
DType wkspace_dtype; DType wkspace_dtype;
bool is_training; bool is_training;
}; };
pybind11::bytes PackCustomCallFusedAttnDescriptor( pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t input_batch, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
size_t wkspace_size, float scaling_factor, float dropout_probability, size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
DType dtype, DType wkspace_dtype, bool is_training); bool is_training);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
...@@ -207,55 +208,19 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers,
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque, void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len); std::size_t opaque_len);
pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t max_seqlen,
size_t attn_heads, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t max_seqlen,
size_t attn_heads, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetFusedAttnForwardWorkspaceSizes( pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training); NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training);
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training); NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training);
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
...
...@@ -26,7 +26,7 @@ from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax from .module import LayerNorm, Softmax
from ..fused_attn import AttnBiasType, AttnMaskType, QKVLayout from ..fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from ..fused_attn import is_fused_attn_kernel_available, canonicalize_attn_mask_type from ..fused_attn import is_fused_attn_kernel_available, canonicalize_attn_mask_type
from ..fused_attn import self_fused_attn, cross_fused_attn, fused_attn from ..fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked, fused_attn
from ..softmax import SoftmaxType from ..softmax import SoftmaxType
from ..sharding import num_of_devices from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
...@@ -190,16 +190,19 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-publi
def convert_to_softmax_type(attn_mask_type, mask): def convert_to_softmax_type(attn_mask_type, mask):
"""Convert the attn_mask_type to SoftmaxType""" """Convert the attn_mask_type to SoftmaxType"""
# mask is ignored for no_mask and causal_mask
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
mask = None
if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]: if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
return SoftmaxType.SCALED_UPPER_TRIANG_MASKED return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]: if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]:
if mask is not None: if mask is not None:
return SoftmaxType.SCALED_MASKED return SoftmaxType.SCALED_MASKED, mask
return SoftmaxType.SCALED return SoftmaxType.SCALED, mask
raise ValueError(f"Unsupported {attn_mask_type=}, " raise ValueError(f"Unsupported {attn_mask_type=}, supported attn_mask_type="
"supported attn_mask_type = {'causal', 'padding'}") "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}")
softmax_type = convert_to_softmax_type(self.attn_mask_type, mask) softmax_type, mask = convert_to_softmax_type(self.attn_mask_type, mask)
attn_weights = Softmax(softmax_type=softmax_type, attn_weights = Softmax(softmax_type=softmax_type,
scale_factor=fused_scale_factor)(attn_weights, mask, scale_factor=fused_scale_factor)(attn_weights, mask,
...@@ -266,15 +269,15 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
qkv_packed = query qkv_packed = query
if self.transpose_batch_sequence: if self.transpose_batch_sequence:
qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4]) qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4])
x = self_fused_attn(qkv_packed, x = fused_attn_qkvpacked(qkv_packed,
bias, bias,
mask, mask,
seed, seed,
attn_mask_type=self.attn_mask_type, attn_mask_type=self.attn_mask_type,
attn_bias_type=self.attn_bias_type, attn_bias_type=self.attn_bias_type,
scaling_factor=scale_factor, scaling_factor=scale_factor,
dropout_probability=self.attention_dropout, dropout_probability=self.attention_dropout,
is_training=not deterministic) is_training=not deterministic)
elif self.qkv_layout == QKVLayout.BSHD_BS2HD: elif self.qkv_layout == QKVLayout.BSHD_BS2HD:
"""kvpacked format, treat """kvpacked format, treat
query: query tensor, shape = [..., h, d] query: query tensor, shape = [..., h, d]
...@@ -285,16 +288,16 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
if self.transpose_batch_sequence: if self.transpose_batch_sequence:
query = query.transpose([1, 0, 2, 3]) query = query.transpose([1, 0, 2, 3])
kv_packed = kv_packed.transpose([1, 0, 2, 3, 4]) kv_packed = kv_packed.transpose([1, 0, 2, 3, 4])
x = cross_fused_attn(query, x = fused_attn_kvpacked(query,
kv_packed, kv_packed,
bias, bias,
mask, mask,
seed, seed,
attn_mask_type=self.attn_mask_type, attn_mask_type=self.attn_mask_type,
attn_bias_type=self.attn_bias_type, attn_bias_type=self.attn_bias_type,
scaling_factor=scale_factor, scaling_factor=scale_factor,
dropout_probability=self.attention_dropout, dropout_probability=self.attention_dropout,
is_training=not deterministic) is_training=not deterministic)
elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD: elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
if self.transpose_batch_sequence: if self.transpose_batch_sequence:
query = query.transpose([1, 0, 2, 3]) query = query.transpose([1, 0, 2, 3])
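For reference, a hedged usage sketch of the renamed entry points as they are called above. The enum member names (AttnBiasType.NO_BIAS, AttnMaskType.PADDING_MASK), passing None for bias and seed, and the concrete shapes are assumptions for illustration; running it also requires a GPU with a supported cuDNN fused attention backend:

import jax
import jax.numpy as jnp
from transformer_engine.jax.fused_attn import (AttnBiasType, AttnMaskType,
                                               fused_attn_qkvpacked, fused_attn_kvpacked)

b, s, h, d = 2, 128, 16, 64
rng = jax.random.PRNGKey(0)
mask = jnp.zeros((b, 1, s, s), dtype=jnp.bool_)   # True would mask a position out

# qkvpacked: one [b, s, 3, h, d] tensor, e.g. self attention after a fused QKV projection.
qkv = jax.random.normal(rng, (b, s, 3, h, d), dtype=jnp.bfloat16)
out = fused_attn_qkvpacked(qkv, None, mask, None,
                           attn_bias_type=AttnBiasType.NO_BIAS,
                           attn_mask_type=AttnMaskType.PADDING_MASK,
                           scaling_factor=1.0 / d ** 0.5,
                           dropout_probability=0.0,
                           is_training=False)

# kvpacked: separate query plus a [b, s, 2, h, d] key/value tensor, e.g. cross attention.
q = jax.random.normal(rng, (b, s, h, d), dtype=jnp.bfloat16)
kv = jax.random.normal(rng, (b, s, 2, h, d), dtype=jnp.bfloat16)
out_cross = fused_attn_kvpacked(q, kv, None, mask, None,
                                attn_bias_type=AttnBiasType.NO_BIAS,
                                attn_mask_type=AttnMaskType.PADDING_MASK,
                                scaling_factor=1.0 / d ** 0.5,
                                dropout_probability=0.0,
                                is_training=False)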
...@@ -358,11 +361,27 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-method
attention_dropout: float, default = 0.0 attention_dropout: float, default = 0.0
Dropout probability for the dropout op after the softmax. Dropout probability for the dropout op after the softmax.
attn_mask_type: str, default = 'causal' attn_mask_type: str, default = 'causal'
Type of the attention mask passed into softmax operation in the self attention. This parameter specifies the type of attention mask to be applied during the softmax
Available options: {'no_mask', 'padding', 'causal', 'causal_padding'} operation.
Introduced in v0.10.0. Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}
Each described below:
* no_mask: No attention mask is applied. This means the attention will consider the
full sequence without any restrictions.
* padding: Indicates the presence of padding at the end of each sequence.
Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
:attr:`__call__` method to specify the padding positions.
* causal: An upper triangular mask is applied to the softmax inputs,
ensuring that the prediction for a certain position is only dependent on known outputs
from positions before it.
* causal_padding / padding_causal: A combination of both causal and padding masks.
Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.
.. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.
attn_bias_type: Optional[str], default = None attn_bias_type: Optional[str], default = None
Type of the attention bias passed in the self attention. Type of the attention bias passed in the attention.
Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
When the default is used, the type is decided automatically by the MHA's bias parameter:
:attr:`post_scale_bias` if a bias is provided, otherwise :attr:`no_bias`.
...@@ -438,6 +457,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-method
mask: jax.numpy.ndarray, default = None mask: jax.numpy.ndarray, default = None
Boolean tensor used to mask out the attention softmax input. Boolean tensor used to mask out the attention softmax input.
:attr:`True` means to mask out the corresponding values. :attr:`True` means to mask out the corresponding values.
Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
bias: jax.numpy.ndarray, default = None bias: jax.numpy.ndarray, default = None
A tensor used to shift attention softmax input. A tensor used to shift attention softmax input.
*: *:
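The 'padding' option documented above expects a boolean mask of shape [batch, 1, max_seqlen_q, max_seqlen_kv] in which True marks positions to be masked out. A minimal sketch of deriving such a mask from per-sequence valid lengths (the helper name and the lengths are illustrative only):

import jax.numpy as jnp

def make_padding_mask(q_lengths, kv_lengths, max_seqlen_q, max_seqlen_kv):
    """Return a [batch, 1, max_seqlen_q, max_seqlen_kv] bool mask; True = mask out."""
    q_valid = jnp.arange(max_seqlen_q)[None, :] < q_lengths[:, None]      # [b, sq]
    kv_valid = jnp.arange(max_seqlen_kv)[None, :] < kv_lengths[:, None]   # [b, skv]
    valid = q_valid[:, :, None] & kv_valid[:, None, :]                    # [b, sq, skv]
    return ~valid[:, None, :, :]

mask = make_padding_mask(jnp.array([5, 3]), jnp.array([4, 4]), 8, 8)
print(mask.shape)  # (2, 1, 8, 8)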
...@@ -639,9 +659,25 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
attention_dropout: float, default = 0.0 attention_dropout: float, default = 0.0
Dropout probability for the dropout op after the softmax. Dropout probability for the dropout op after the softmax.
attn_mask_type: str, default = 'causal' attn_mask_type: str, default = 'causal'
Type of the attention mask passed into softmax operation in the attention. This parameter specifies the type of attention mask to be applied during the softmax
Available options: {'no_mask', 'padding', 'causal', 'causal_padding'} operation.
Introduced in v0.10.0. Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}
Each described below:
* no_mask: No attention mask is applied. This means the attention will consider the
full sequence without any restrictions.
* padding: Indicates the presence of padding at the end of each sequence.
Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
:attr:`__call__` method to specify the padding positions.
* causal: An upper triangular mask is applied to the softmax inputs,
ensuring that the prediction for a certain position is only dependent on known outputs
from positions before it.
* causal_padding / padding_causal: A combination of both causal and padding masks.
Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.
.. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.
attn_bias_type: Optional[str], default = None attn_bias_type: Optional[str], default = None
Type of the attention bias passed in the attention. Type of the attention bias passed in the attention.
Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}. Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
...@@ -809,6 +845,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
mask: jax.numpy.ndarray, default = None mask: jax.numpy.ndarray, default = None
Boolean tensor used to mask out the attention softmax input. Boolean tensor used to mask out the attention softmax input.
:attr:`True` means mask out the corresponding values. :attr:`True` means mask out the corresponding values.
Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
bias: jax.numpy.ndarray, default = None bias: jax.numpy.ndarray, default = None
A tensor used to shift the attention softmax input. A tensor used to shift the attention softmax input.
* *
...@@ -1299,9 +1336,25 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1299,9 +1336,25 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
is added after self-attention. This can be used for structures like `T5`
Transformer in conjunction with the TransformerLayerType.ENCODER option.
self_attn_mask_type: str, default = 'causal' self_attn_mask_type: str, default = 'causal'
This parameter specifies the type of attention mask to be applied during the softmax
operation in the self attention.
Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'},
each described below:
* no_mask: No attention mask is applied. This means the self attention will consider the
full sequence without any restrictions.
* padding: Indicates the presence of padding at the end of each sequence.
Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
:attr:`__call__` method to specify the padding positions.
* causal: An upper triangular mask is applied to the softmax inputs,
ensuring that the prediction for a certain position is only dependent on known outputs
from positions before it.
* causal_padding / padding_causal: A combination of both causal and padding masks.
Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.
.. note:: :attr:`attention_mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.
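The fused kernels apply the causal structure internally, which is why the mask argument is ignored
for 'causal'. For reference only, an equivalent combined causal-plus-padding mask for an unfused
implementation could be assembled as follows (names and shapes are illustrative):

    import jax.numpy as jnp

    def make_causal_padding_mask(lengths, max_seqlen):
        idx = jnp.arange(max_seqlen)
        causal = idx[None, :] > idx[:, None]               # [sq, skv], True above the diagonal
        padding = idx[None, :] >= lengths[:, None]         # [batch, skv], True at padded positions
        mask = causal[None, :, :] | padding[:, None, :]    # [batch, sq, skv], True = mask out
        return mask[:, None, :, :]                         # [batch, 1, sq, skv]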
self_attn_bias_type: Optional[str], default = None self_attn_bias_type: Optional[str], default = None
Type of the attention bias passed into the self attention. Type of the attention bias passed into the self attention.
Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}. Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
...@@ -1420,9 +1473,12 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1420,9 +1473,12 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
:attr:`layer_type=TransformerLayerType.DECODER`. :attr:`layer_type=TransformerLayerType.DECODER`.
attention_mask : jax.numpy.ndarray, default = None attention_mask : jax.numpy.ndarray, default = None
Boolean tensor used to mask out self-attention softmax input. Boolean tensor used to mask out self-attention softmax input.
:attr:`True` means mask out the corresponding values.
Ignored when :attr:`self.self_attn_mask_type` is either 'no_mask' or 'causal'.
encoder_decoder_mask: jax.numpy.ndarray, default = None encoder_decoder_mask: jax.numpy.ndarray, default = None
Boolean tensor used to mask out cross-attention softmax input when Boolean tensor used to mask out cross-attention softmax input when
:attr:`layer_type=TransformerLayerType.DECODER`. :attr:`layer_type=TransformerLayerType.DECODER`.
:attr:`True` means mask out the corresponding values.
deterministic: bool, default = False deterministic: bool, default = False
Disable dropout layers if set to True. Disable dropout layers if set to True.
decode: bool, default = False decode: bool, default = False
......
...@@ -14,20 +14,29 @@ from transformer_engine_jax import NVTE_Mask_Type ...@@ -14,20 +14,29 @@ from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout from transformer_engine_jax import NVTE_QKV_Layout
from .cpp_extensions import FusedAttnHelper from .cpp_extensions import FusedAttnHelper
from .cpp_extensions import fused_attn_fwd_kvpacked, fused_attn_bwd_kvpacked
from .cpp_extensions import fused_attn_fwd_qkvpacked, fused_attn_bwd_qkvpacked
from .cpp_extensions import fused_attn_fwd, fused_attn_bwd from .cpp_extensions import fused_attn_fwd, fused_attn_bwd
class AttnBiasType(Enum):
    """
    NO_BIAS: Softmax is performed as softmax(scale * qk)
    PRE_SCALE_BIAS: Softmax is performed as softmax(scale * (qk + bias))
    POST_SCALE_BIAS: Softmax is performed as softmax(scale * qk + bias)
    """
    NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS
    PRE_SCALE_BIAS = NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS
    POST_SCALE_BIAS = NVTE_Bias_Type.NVTE_POST_SCALE_BIAS
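As a rough, unfused reference for the three bias placements (not part of the library API):

    import jax

    def reference_softmax(logits, bias, scale, bias_type):
        # logits: raw Q.K^T scores; bias must be broadcastable to the same shape
        if bias_type == AttnBiasType.NO_BIAS:
            return jax.nn.softmax(scale * logits, axis=-1)
        if bias_type == AttnBiasType.PRE_SCALE_BIAS:
            return jax.nn.softmax(scale * (logits + bias), axis=-1)
        return jax.nn.softmax(scale * logits + bias, axis=-1)    # POST_SCALE_BIAS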
class AttnMaskType(Enum):
    """
    NO_MASK: No attention mask is applied.
    PADDING_MASK: Indicates the presence of padding at the end of each sequence.
    CAUSAL_MASK: An upper triangular mask is applied to the softmax inputs.
    PADDING_CAUSAL_MASK: A combination of both causal and padding masks.
    """
    NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK
    PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK
    CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK
...@@ -47,99 +56,105 @@ def canonicalize_attn_mask_type(attn_mask_type: str): ...@@ -47,99 +56,105 @@ def canonicalize_attn_mask_type(attn_mask_type: str):
    The overhead between the padding and non-padding versions should be small.
    However, we will relax this limitation in the near future.
    """
    match attn_mask_type:
        case 'no_mask':
            return AttnMaskType.NO_MASK
        case 'padding':
            return AttnMaskType.PADDING_MASK
        case 'causal':
            return AttnMaskType.CAUSAL_MASK
        case 'padding_causal' | 'causal_padding':
            return AttnMaskType.PADDING_CAUSAL_MASK
    raise ValueError(f"Unsupported {attn_mask_type=}, supported attn_mask_type="
                     "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}")
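For illustration, both spellings of the combined mask resolve to the same enum member:

    assert canonicalize_attn_mask_type('padding_causal') is AttnMaskType.PADDING_CAUSAL_MASK
    assert canonicalize_attn_mask_type('causal_padding') is AttnMaskType.PADDING_CAUSAL_MASK
    assert canonicalize_attn_mask_type('causal') is AttnMaskType.CAUSAL_MASK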
def is_fused_attn_kernel_available(q_dtype, kv_dtype, qkv_layout, attn_bias_type, attn_mask_type,
                                   dropout_probability, q_num_heads, kv_num_heads, q_max_seqlen,
                                   kv_max_seqlen, head_dim):
    """
    To check whether the fused attention kernel is supported
    """
    return FusedAttnHelper(q_dtype, kv_dtype, qkv_layout.value, attn_bias_type.value,
                           attn_mask_type.value, dropout_probability, q_num_heads, kv_num_heads,
                           q_max_seqlen, kv_max_seqlen, head_dim).is_fused_attn_kernel_available()
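An illustrative capability check before dispatching to the fused path; `layout` stands for
whichever QKVLayout member matches the packed inputs, and every number below is a placeholder:

    import jax.numpy as jnp

    def pick_attn_impl(layout):
        ok = is_fused_attn_kernel_available(jnp.bfloat16, jnp.bfloat16, layout,
                                            AttnBiasType.NO_BIAS,
                                            AttnMaskType.PADDING_CAUSAL_MASK,
                                            0.0,           # dropout_probability
                                            16, 16,        # q_num_heads, kv_num_heads
                                            512, 512,      # q_max_seqlen, kv_max_seqlen
                                            64)            # head_dim
        return 'fused' if ok else 'reference'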
def fused_attn_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
                         seed: jnp.ndarray | None, attn_bias_type: AttnBiasType,
                         attn_mask_type: AttnMaskType, scaling_factor: float,
                         dropout_probability: float, is_training: bool):
    """
    Fused attention with the qkvpacked inputs
    """
    output = _fused_attn_qkvpacked(qkv,
                                   bias,
                                   mask,
                                   seed,
                                   attn_bias_type=attn_bias_type,
                                   attn_mask_type=attn_mask_type,
                                   scaling_factor=scaling_factor,
                                   dropout_probability=dropout_probability,
                                   is_training=is_training)
    return output
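A sketch of calling the wrapper; the shapes follow the qkvpacked convention used by the unit tests
(Q, K and V packed along axis -3) and all values are placeholders:

    import jax.numpy as jnp

    batch, seqlen, heads, head_dim = 2, 128, 16, 64
    qkv = jnp.zeros((batch, seqlen, 3, heads, head_dim), dtype=jnp.bfloat16)
    mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.bool_)    # True = mask out

    out = fused_attn_qkvpacked(qkv, None, mask, None,    # bias=None, seed=None (no dropout)
                               attn_bias_type=AttnBiasType.NO_BIAS,
                               attn_mask_type=AttnMaskType.PADDING_CAUSAL_MASK,
                               scaling_factor=1.0 / head_dim**0.5,
                               dropout_probability=0.0,
                               is_training=True)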
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _fused_attn_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
                          seed: jnp.ndarray | None, attn_bias_type: AttnBiasType,
                          attn_mask_type: AttnMaskType, scaling_factor: float,
                          dropout_probability: float, is_training: bool):
    output, _ = _fused_attn_fwd_qkvpacked_rule(qkv, bias, mask, seed, attn_bias_type,
                                               attn_mask_type, scaling_factor, dropout_probability,
                                               is_training)
    return output
def _fused_attn_fwd_qkvpacked_rule(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
                                   seed: jnp.ndarray | None, attn_bias_type: AttnBiasType,
                                   attn_mask_type: AttnMaskType, scaling_factor: float,
                                   dropout_probability: float, is_training: bool):
    if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
        batch, seqlen, *_ = qkv.shape
        actual_seqlen = jnp.full((batch,), seqlen, dtype=jnp.int32)
    else:
        assert mask is not None
        mask = jnp.logical_not(mask)
        actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
    output, softmax_aux, rng_state = fused_attn_fwd_qkvpacked(
        qkv,
        bias,
        actual_seqlen,
        seed,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training)
    output = checkpoint_name(output, 'context')
    softmax_aux = checkpoint_name(softmax_aux, 'context')
    rng_state = checkpoint_name(rng_state, 'context')
    return output, (qkv, bias, softmax_aux, rng_state, output, actual_seqlen)
def _fused_attn_bwd_qkvpacked_rule(attn_bias_type, attn_mask_type, scaling_factor,
                                   dropout_probability, is_training, ctx, dz):
    qkv, bias, softmax_aux, rng_state, output, actual_seqlen = ctx
    grad_qkv, grad_bias = fused_attn_bwd_qkvpacked(qkv,
                                                   bias,
                                                   softmax_aux,
                                                   rng_state,
                                                   output,
                                                   dz,
                                                   actual_seqlen,
                                                   attn_bias_type=attn_bias_type.value,
                                                   attn_mask_type=attn_mask_type.value,
                                                   scaling_factor=scaling_factor,
                                                   dropout_probability=dropout_probability,
                                                   is_training=is_training)
    if attn_bias_type == AttnBiasType.NO_BIAS:
        grad_bias = None
...@@ -147,91 +162,96 @@ def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dr ...@@ -147,91 +162,96 @@ def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dr
    return grad_qkv, grad_bias, None, None


_fused_attn_qkvpacked.defvjp(_fused_attn_fwd_qkvpacked_rule, _fused_attn_bwd_qkvpacked_rule)
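Because the private function carries a custom VJP, the public wrapper composes with jax.grad and
jax.jit like any other JAX function; a sketch reusing the placeholder `qkv` and `mask` from the
example above:

    import jax
    import jax.numpy as jnp

    def loss_fn(qkv):
        out = fused_attn_qkvpacked(qkv, None, mask, None,
                                   attn_bias_type=AttnBiasType.NO_BIAS,
                                   attn_mask_type=AttnMaskType.PADDING_CAUSAL_MASK,
                                   scaling_factor=0.125,
                                   dropout_probability=0.0,
                                   is_training=True)
        return jnp.mean(out)

    grad_qkv = jax.jit(jax.grad(loss_fn))(qkv)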
def fused_attn_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
                        seed: jnp.ndarray, attn_bias_type: AttnBiasType,
                        attn_mask_type: AttnMaskType, scaling_factor: float,
                        dropout_probability: float, is_training: bool):
    """
    Fused attention with the kvpacked inputs
    """
    output = _fused_attn_kvpacked(q,
                                  kv,
                                  bias,
                                  mask,
                                  seed,
                                  attn_bias_type=attn_bias_type,
                                  attn_mask_type=attn_mask_type,
                                  scaling_factor=scaling_factor,
                                  dropout_probability=dropout_probability,
                                  is_training=is_training)
    return output
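For cross attention the query keeps its own sequence length while `kv` packs K and V along
axis -3; an illustrative call with placeholder shapes:

    import jax.numpy as jnp

    q = jnp.zeros((2, 128, 16, 64), dtype=jnp.bfloat16)          # [batch, sq, heads, head_dim]
    kv = jnp.zeros((2, 256, 2, 16, 64), dtype=jnp.bfloat16)      # [batch, skv, 2, heads, head_dim]
    cross_mask = jnp.zeros((2, 1, 128, 256), dtype=jnp.bool_)    # True = mask out

    out = fused_attn_kvpacked(q, kv, None, cross_mask, None,
                              attn_bias_type=AttnBiasType.NO_BIAS,
                              attn_mask_type=AttnMaskType.PADDING_MASK,
                              scaling_factor=0.125,
                              dropout_probability=0.0,
                              is_training=False)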
@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9))
def _fused_attn_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
                         seed: jnp.ndarray, attn_bias_type: AttnBiasType,
                         attn_mask_type: AttnMaskType, scaling_factor: float,
                         dropout_probability: float, is_training: bool):
    output, _ = _fused_attn_fwd_kvpacked_rule(q, kv, bias, mask, seed, attn_bias_type,
                                              attn_mask_type, scaling_factor, dropout_probability,
                                              is_training)
    return output
def _fused_attn_fwd_kvpacked_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
                                  scaling_factor, dropout_probability, is_training):
    if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
        batch, s_q, *_ = q.shape
        s_kv = kv.shape[1]
        q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
        kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
    else:
        assert mask is not None
        mask = jnp.logical_not(mask)
        q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
        if attn_mask_type == AttnMaskType.PADDING_MASK:
            kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
        else:
            # When mask is causal, the actual seqlen is not the last row, use max to find it
            kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))
    output, softmax_aux, rng_state = fused_attn_fwd_kvpacked(
        q,
        kv,
        bias,
        q_actual_seqlen,
        kv_actual_seqlen,
        seed,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training)
    output = checkpoint_name(output, 'context')
    softmax_aux = checkpoint_name(softmax_aux, 'context')
    rng_state = checkpoint_name(rng_state, 'context')
    return output, (q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen)
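A tiny worked example (illustrative values) of the sequence-length recovery above: flip the
boolean mask, sum over the query axis for per-batch query lengths, sum over the kv axis for kv
lengths, and take the row-wise maximum when the mask also encodes the causal structure.

    import jax.numpy as jnp

    mask = jnp.ones((1, 1, 4, 4), dtype=jnp.bool_)     # everything masked out ...
    mask = mask.at[:, :, :3, :3].set(False)            # ... except a 3-token valid block
    keep = jnp.logical_not(mask)
    q_len = jnp.sum(keep, axis=-2, dtype=jnp.int32)[..., 0, 0]                   # -> Array([3])
    kv_len = jnp.max(jnp.sum(keep, axis=-1, dtype=jnp.int32), axis=(-1, -2))     # -> Array([3])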
def _fused_attn_bwd_kvpacked_rule(attn_bias_type, attn_mask_type, scaling_factor,
                                  dropout_probability, is_training, ctx, dz):
    q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx
    grad_q, grad_kv, grad_bias = fused_attn_bwd_kvpacked(q,
                                                         kv,
                                                         bias,
                                                         softmax_aux,
                                                         rng_state,
                                                         output,
                                                         dz,
                                                         q_actual_seqlen,
                                                         kv_actual_seqlen,
                                                         attn_bias_type=attn_bias_type.value,
                                                         attn_mask_type=attn_mask_type.value,
                                                         scaling_factor=scaling_factor,
                                                         dropout_probability=dropout_probability,
                                                         is_training=is_training)
    if attn_bias_type == AttnBiasType.NO_BIAS:
        grad_bias = None
...@@ -239,7 +259,7 @@ def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, d ...@@ -239,7 +259,7 @@ def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, d
    return grad_q, grad_kv, grad_bias, None, None


_fused_attn_kvpacked.defvjp(_fused_attn_fwd_kvpacked_rule, _fused_attn_bwd_kvpacked_rule)
def fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
...@@ -277,15 +297,16 @@ def _fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarra ...@@ -277,15 +297,16 @@ def _fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarra
def _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
                         dropout_probability, is_training):
    if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
        batch, s_q, *_ = q.shape
        s_kv = k.shape[1]
        q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
        kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
    else:
        assert mask is not None
        mask = jnp.logical_not(mask)
        q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
        if attn_mask_type == AttnMaskType.PADDING_MASK:
            kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
        else:
            # When mask is causal, the actual seqlen is not the last row, use max to find it
......