Unverified Commit 8e672ff0 authored by Reese Wang, committed by GitHub

[JAX] Refactor fused attention (#711)



* Remove unused headers
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Unify the fused attn workspace size cpp code
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Reduce the skipped cases
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Rename self/cross attention to qkvpacked/kvpacked
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Update attention mask docs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refine the attn mask implementations
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent b855656b
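
The central rename in this change replaces the self/cross attention entry points with layout-named ones. The sketch below lists the old and new names side by side; the positional arguments mirror the updated call sites in the tests further down, and anything beyond that about the signatures is an assumption, not a complete API reference.

# Renamed public entry points (a minimal sketch; argument lists follow the
# updated test call sites in this diff and are not the full signatures).
from transformer_engine.jax.fused_attn import (
    fused_attn_qkvpacked,   # formerly self_fused_attn:  packed QKV, BS3HD layout
    fused_attn_kvpacked,    # formerly cross_fused_attn: separate Q + packed KV, BSHD_BS2HD layout
    fused_attn,             # separate Q/K/V, BSHD_BSHD_BSHD layout
)

# out = fused_attn_qkvpacked(qkv, bias, mask, dropout_rng, **kwargs)
# out = fused_attn_kvpacked(q, kv, bias, mask, dropout_rng, **kwargs)
# out = fused_attn(q, k, v, bias, mask, dropout_rng, **kwargs)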
......@@ -16,7 +16,7 @@ from distributed_test_base import compare_ops
from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn
from transformer_engine.jax.fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
DTYPES = [jnp.float16, jnp.bfloat16]
......@@ -86,7 +86,7 @@ class TestDistributedSelfAttn:
def target_func(qkv, bias, mask):
return jnp.mean(
self_fused_attn(qkv,
fused_attn_qkvpacked(qkv,
bias,
mask,
None,
......@@ -192,7 +192,7 @@ class TestDistributedCrossAttn:
def target_func(q, kv, mask):
return jnp.mean(
cross_fused_attn(q,
fused_attn_kvpacked(q,
kv,
None,
mask,
......
......@@ -2,8 +2,6 @@
#
# See LICENSE for license information.
"""Tests for fused attention"""
import sys
from enum import Enum
from dataclasses import dataclass
from functools import partial
......@@ -21,7 +19,7 @@ from jax import value_and_grad, jit
from jax.typing import ArrayLike, DTypeLike
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn, fused_attn
from transformer_engine.jax.fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked, fused_attn
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine_jax import NVTE_Fused_Attn_Backend
......@@ -144,11 +142,11 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng
case QKVLayout.BS3HD:
query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
qkv = jnp.concatenate((query, key, value), axis=-3)
return self_fused_attn(qkv, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
return fused_attn_qkvpacked(qkv, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
case QKVLayout.BSHD_BS2HD:
key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
kv = jnp.concatenate((key, value), axis=-3)
return cross_fused_attn(query, kv, bias, mask, dropout_rng,
return fused_attn_kvpacked(query, kv, bias, mask, dropout_rng,
**kwargs).astype(query.dtype)
case QKVLayout.BSHD_BSHD_BSHD:
return fused_attn(query, key, value, bias, mask, dropout_rng,
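
For reference, the packing that customcall_fused_dpa performs in the hunk above can be reproduced with plain jax.numpy. This standalone sketch (arbitrary shapes, no TE calls) shows how the BS3HD and BSHD_BS2HD inputs are built before calling the packed entry points.

# Standalone packing sketch (pure jax.numpy, arbitrary example shapes):
import jax.numpy as jnp

b, s, h, d = 2, 128, 16, 64
query = jnp.zeros((b, s, h, d), jnp.bfloat16)
key = jnp.zeros((b, s, h, d), jnp.bfloat16)
value = jnp.zeros((b, s, h, d), jnp.bfloat16)

# BS3HD: stack Q/K/V along a new axis=-3 for fused_attn_qkvpacked
qkv = jnp.concatenate([jnp.expand_dims(x, axis=-3) for x in (query, key, value)], axis=-3)
assert qkv.shape == (b, s, 3, h, d)

# BSHD_BS2HD: stack K/V only for fused_attn_kvpacked
kv = jnp.concatenate([jnp.expand_dims(x, axis=-3) for x in (key, value)], axis=-3)
assert kv.shape == (b, s, 2, h, d)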
......@@ -156,6 +154,10 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng
class BiasShape(Enum):
"""
Enum class to represent the different bias shapes used in the fused attention.
"""
BIAS_1HSS = '1HSS'
BIAS_B1SS = 'B1SS'
BIAS_BHSS = 'BHSS'
......@@ -188,17 +190,16 @@ class FusedAttnRunner:
if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv:
pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.")
self.backend = FusedAttnHelper(
self.dtype, self.dtype, self.qkv_layout.value, self.attn_bias_type.value,
self.attn_mask_type.value, self.dropout_prob, self.num_heads_q, self.num_heads_kv,
self.max_seqlen_q, self.max_seqlen_kv, self.head_dim).get_fused_attn_backend()
self.backend = FusedAttnHelper(self.dtype, self.dtype, self.qkv_layout.value,
self.attn_bias_type.value, self.attn_mask_type.value,
self.dropout_prob, self.num_heads_q, self.num_heads_kv,
self.max_seqlen_q, self.max_seqlen_kv,
self.head_dim).get_fused_attn_backend()
if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
pytest.skip("Unsupported inputs combination or device compute capability.")
if self.bias_shape != BiasShape.BIAS_1HSS:
if self.attn_bias_type != AttnBiasType.POST_SCALE_BIAS:
pytest.skip("B1SS, BHSS and 11SS bias shapes require POST_SCALE_BIAS.")
elif self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for "
"AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK.")
elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
......@@ -213,7 +214,9 @@ class FusedAttnRunner:
q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
k_shape = v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim)
if self.bias_shape == BiasShape.BIAS_1HSS:
if self.attn_bias_type == AttnBiasType.NO_BIAS:
bias_shape = None
elif self.bias_shape == BiasShape.BIAS_1HSS:
bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
elif self.bias_shape == BiasShape.BIAS_B1SS:
bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
......@@ -222,7 +225,7 @@ class FusedAttnRunner:
elif self.bias_shape == BiasShape.BIAS_11SS:
bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
else:
pytest.xfail("PyTest attempted to use an unrecognized bias layout!")
pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")
self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.)
self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.)
......@@ -327,8 +330,8 @@ class FusedAttnRunner:
**kwargs), arg_nums))
jitted_reference = jit(
value_and_grad(
lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args,
**kwargs), arg_nums))
lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
arg_nums))
primitive_out, primitive_dgrad = jitted_primitive(*args)
reference_out, reference_dgrad = jitted_reference(*args)
......@@ -361,9 +364,9 @@ class FusedAttnRunner:
primitive_dbias = jnp.float32(primitive_dgrad[3])
reference_dbias = jnp.float32(reference_dgrad[3])
assert_allclose(
primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
jnp.zeros_like(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:]),
assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
jnp.zeros_like(primitive_dbias[..., self.valid_len_q:,
self.valid_len_kv:]),
dtype=self.dtype)
# dbias padded part
......@@ -376,15 +379,13 @@ class FusedAttnRunner:
reference_dbias[..., :self.valid_len_q, :self.valid_len_kv],
dtype=self.dtype)
@pytest.mark.parametrize('bias_shape', [
pytest.param(BiasShape.BIAS_1HSS, id='1-H-S-S'),
pytest.param(BiasShape.BIAS_B1SS, id='B-1-S-S'),
pytest.param(BiasShape.BIAS_BHSS, id='B-H-S-S'),
pytest.param(BiasShape.BIAS_11SS, id='1-1-S-S'),
])
@pytest.mark.parametrize('attn_bias_type', [
pytest.param(AttnBiasType.NO_BIAS, id='NO_BIAS'),
pytest.param(AttnBiasType.POST_SCALE_BIAS, id='POST_SCALE_BIAS'),
@pytest.mark.parametrize('attn_bias_type, bias_shape', [
pytest.param(AttnBiasType.NO_BIAS, None, id='NO_BIAS'),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id='POST_SCALE_BIAS-1HSS'),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id='POST_SCALE_BIAS-B1SS'),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id='POST_SCALE_BIAS-BHSS'),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id='POST_SCALE_BIAS-11SS'),
])
@pytest.mark.parametrize('attn_mask_type', [
pytest.param(AttnMaskType.NO_MASK, id='NO_MASK'),
......@@ -399,31 +400,32 @@ class FusedAttnRunner:
])
@pytest.mark.parametrize('dtype', [
pytest.param(jnp.bfloat16, id="BF16"),
pytest.param(jnp.float16, id="FP16")
pytest.param(jnp.float16, id="FP16"),
])
@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d',[
@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d', [
pytest.param(32, 128, 128, 16, 16, 64, id='32-128-128-16-16-64-SELF'),
pytest.param( 4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'),
pytest.param(4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'),
pytest.param(32, 512, 128, 16, 16, 64, id='32-512-128-16-16-64-CROSS'),
pytest.param( 4, 2048, 1024, 12, 12, 64, id='4-2048-1048-12-12-64-CROSS'),
pytest.param(4, 2048, 1024, 12, 12, 64, id='4-2048-1048-12-12-64-CROSS'),
pytest.param(32, 128, 128, 16, 8, 64, id='32-128-128-16-8-64-GQA'),
pytest.param( 4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA')
pytest.param(4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA'),
])
@pytest.mark.parametrize('dropout_prob', [
pytest.param(0.0, id="DROP_0.0"),
pytest.param(0.1, id="DROP_0.1")
])
@pytest.mark.parametrize('is_training', [
pytest.param(True, id='TRAINING'),
pytest.param(False, id='INFERENCE'),
pytest.param(0.1, id="DROP_0.1"),
])
class TestFusedAttn:
"""
Fused attention tester
"""
@staticmethod
def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
dropout_prob, dtype, is_training, qkv_layout, bias_shape):
@pytest.mark.parametrize('is_training', [
pytest.param(True, id='TRAINING'),
pytest.param(False, id='INFERENCE'),
])
def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
dtype, is_training, qkv_layout, bias_shape):
"""
Test forward with parameterized configs
"""
......@@ -432,13 +434,11 @@ class TestFusedAttn:
runner.test_forward()
@staticmethod
def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
dropout_prob, dtype, is_training, qkv_layout, bias_shape):
def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
dtype, qkv_layout, bias_shape):
"""
Test backward with parameterized configs
"""
if not is_training:
pytest.skip("Backward pass does not support inference.")
runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
dropout_prob, dtype, True, qkv_layout, bias_shape)
runner.test_backward()
......@@ -449,6 +449,7 @@ class TestDecoderLayer:
hidden_dropout_dims=(sequence_dim,),
intermediate_dropout_dims=(sequence_dim,),
layer_type=TransformerLayerType.DECODER,
self_attn_mask_type='padding_causal',
dtype=dtype,
**te_layer_attrs)
ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
......@@ -497,6 +498,7 @@ class TestDecoderLayer:
hidden_dropout_dims=(sequence_dim,),
intermediate_dropout_dims=(sequence_dim,),
layer_type=TransformerLayerType.DECODER,
self_attn_mask_type='padding_causal',
dtype=dtype,
**te_layer_attrs)
ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
......
......@@ -730,8 +730,13 @@ class TestDotProductAttn(TestLayer):
def input_getter(self, shape, dtype):
key = jax.random.PRNGKey(seed=1234)
q_key, k_key, v_key = jax.random.split(key, 3)
return list(map(partial(jax.random.normal, shape=shape, dtype=dtype),
[q_key, k_key, v_key]))
b, s, *_ = shape
if self.attrs[DotProductAttnAttr.TRANSPOSE_BS]:
b, s = s, b
mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
return [
*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]), mask
]
def get_layer_name(self):
return 'dot_product_attn'
......@@ -765,6 +770,7 @@ class TestDotProductAttn(TestLayer):
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', DotProductAttnAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
self.attrs = attrs
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
......@@ -853,9 +859,11 @@ class MultiHeadAttnAttr:
class TestMultiHeadAttn(TestLayer):
def input_getter(self, shape, dtype):
data_key = jax.random.PRNGKey(seed=1234)
return (jax.random.normal(data_key, shape,
dtype), jax.random.normal(data_key, shape, dtype))
key = jax.random.PRNGKey(seed=1234)
q_key, kv_key = jax.random.split(key, 2)
s, b, *_ = shape
mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
return [*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask]
def get_layer_name(self):
return 'multi_head_attn'
......@@ -1183,9 +1191,15 @@ class TransformerLayerAttr:
class TestTransformer(TestLayer):
def input_getter(self, shape, dtype):
data_key = jax.random.PRNGKey(seed=1234)
return (jax.random.normal(data_key, shape,
dtype), jax.random.normal(data_key, shape, dtype))
key = jax.random.PRNGKey(seed=1234)
q_key, kv_key = jax.random.split(key, 2)
b, s, *_ = shape
if self.attrs[TransformerLayerAttr.TRANSPOSE_BS]:
b, s = s, b
mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
return [
*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask, mask
]
def get_layer_name(self):
return 'transformerlayer'
......@@ -1277,6 +1291,7 @@ class TestTransformer(TestLayer):
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', TransformerLayerAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
self.attrs = attrs
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
......@@ -1292,7 +1307,7 @@ class TestTransformer(TestLayer):
fp8_format,
rtol=1e-05,
atol=1e-08):
self.attrs = attrs
ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
......
......@@ -1368,14 +1368,13 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
@staticmethod
def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
return ScaledSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
scale_factor, mesh, arg_infos, result_infos
)
scale_factor, mesh, arg_infos, result_infos)
@staticmethod
def partition(scale_factor, mesh, arg_infos, result_infos):
return ScaledSoftmaxFwdPrimitive.forward_partition(
ScaledSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
)
return ScaledSoftmaxFwdPrimitive.forward_partition(ScaledSoftmaxFwdPrimitive.impl,
scale_factor, mesh, arg_infos,
result_infos)
register_primitive(ScaledSoftmaxFwdPrimitive)
......@@ -1444,14 +1443,13 @@ class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
@staticmethod
def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
return ScaledSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
scale_factor, mesh, arg_infos, result_infos
)
scale_factor, mesh, arg_infos, result_infos)
@staticmethod
def partition(scale_factor, mesh, arg_infos, result_infos):
return ScaledSoftmaxBwdPrimitive.backward_partition(
ScaledSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
)
return ScaledSoftmaxBwdPrimitive.backward_partition(ScaledSoftmaxBwdPrimitive.impl,
scale_factor, mesh, arg_infos,
result_infos)
register_primitive(ScaledSoftmaxBwdPrimitive)
......@@ -1581,14 +1579,12 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
@staticmethod
def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
return ScaledMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
scale_factor, mesh, arg_infos,result_infos
)
scale_factor, mesh, arg_infos, result_infos)
@staticmethod
def partition(scale_factor, mesh, arg_infos, result_infos):
return ScaledMaskedSoftmaxFwdPrimitive.backward_partition(
ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
)
ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos)
register_primitive(ScaledMaskedSoftmaxFwdPrimitive)
......@@ -1660,14 +1656,12 @@ class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
@staticmethod
def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
return ScaledMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
scale_factor, mesh, arg_infos, result_infos
)
scale_factor, mesh, arg_infos, result_infos)
@staticmethod
def partition(scale_factor, mesh, arg_infos, result_infos):
return ScaledMaskedSoftmaxBwdPrimitive.backward_partition(
ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
)
ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos)
register_primitive(ScaledMaskedSoftmaxBwdPrimitive)
......@@ -1749,15 +1743,13 @@ class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
@staticmethod
def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
scale_factor, mesh, arg_infos, result_infos
)
scale_factor, mesh, arg_infos, result_infos)
@staticmethod
def partition(scale_factor, mesh, arg_infos, result_infos):
return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_partition(
ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh,
arg_infos, result_infos
)
ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos,
result_infos)
register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive)
......@@ -1829,15 +1821,13 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
@staticmethod
def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
scale_factor, mesh, arg_infos, result_infos
)
scale_factor, mesh, arg_infos, result_infos)
@staticmethod
def partition(scale_factor, mesh, arg_infos, result_infos):
return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_partition(
ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh,
arg_infos, result_infos
)
ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos,
result_infos)
register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)
......@@ -1859,16 +1849,16 @@ class FusedAttnHelper:
Helper for the fused attention backend
"""
q_type: jnp.dtype
kv_type: jnp.dtype
q_dtype: jnp.dtype
kv_dtype: jnp.dtype
qkv_layout: NVTE_QKV_Layout
attn_bias_type: NVTE_Bias_Type
attn_mask_type: NVTE_Mask_Type
dropout_probability: float
num_heads_q: int
num_heads_kv: int
max_seqlen_q: int
max_seqlen_kv: int
q_num_heads: int
kv_num_heads: int
q_max_seqlen: int
kv_max_seqlen: int
head_dim: int
def is_fused_attn_kernel_available(self):
......@@ -1878,11 +1868,38 @@ class FusedAttnHelper:
def get_fused_attn_backend(self):
"""Get the fused attention kernel backend"""
return transformer_engine_jax.get_fused_attn_backend(
jax_dtype_to_te_dtype(self.q_type), jax_dtype_to_te_dtype(self.kv_type),
jax_dtype_to_te_dtype(self.q_dtype), jax_dtype_to_te_dtype(self.kv_dtype),
self.qkv_layout, self.attn_bias_type, self.attn_mask_type, self.dropout_probability,
self.num_heads_q, self.num_heads_kv, self.max_seqlen_q, self.max_seqlen_kv,
self.q_num_heads, self.kv_num_heads, self.q_max_seqlen, self.kv_max_seqlen,
self.head_dim)
@staticmethod
def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout):
"""Parse qkv aval"""
match qkv_layout:
case NVTE_QKV_Layout.NVTE_BS3HD:
*q_batch_shape, q_max_seqlen, nqkv, attn_heads, q_head_dim = q_aval.shape
kv_batch_shape = q_batch_shape
kv_max_seqlen = q_max_seqlen
num_gqa_groups = attn_heads
kv_head_dim = q_head_dim
assert nqkv == 3
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = k_aval.shape
assert nkv == 2
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape
assert k_aval.shape == v_aval.shape
case _:
raise ValueError(f"Unexpected {qkv_layout=}")
assert q_batch_shape == kv_batch_shape
assert q_head_dim == kv_head_dim
assert q_aval.dtype == k_aval.dtype == v_aval.dtype
return (q_batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, q_head_dim)
@dataclass(frozen=True)
class _FusedAttnRNGStateChecker:
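
The new parse_qkv_aval helper above collapses the three supported layouts into one shape tuple that the forward and backward primitives share. Below is a minimal sketch of what it returns for the unpacked layout, assuming it can be called standalone on jax.ShapeDtypeStruct stand-ins (only .shape and .dtype are read) and that NVTE_QKV_Layout is importable from transformer_engine_jax as used elsewhere in this module.

# Hypothetical standalone use of the shape parser (assumptions noted above):
import jax.numpy as jnp
from jax import ShapeDtypeStruct
from transformer_engine_jax import NVTE_QKV_Layout          # assumed import path
from transformer_engine.jax.cpp_extensions import FusedAttnHelper

q = ShapeDtypeStruct((4, 2048, 12, 64), jnp.bfloat16)   # (batch, q_seqlen, attn_heads, head_dim)
k = ShapeDtypeStruct((4, 1024, 6, 64), jnp.bfloat16)    # (batch, kv_seqlen, num_gqa_groups, head_dim)
v = ShapeDtypeStruct((4, 1024, 6, 64), jnp.bfloat16)

batch_shape, sq, skv, heads, groups, dim = FusedAttnHelper.parse_qkv_aval(
    q, k, v, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD)
# batch_shape == [4], sq == 2048, skv == 1024, heads == 12, groups == 6, dim == 64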
......@@ -1933,46 +1950,50 @@ def generate_cu_seqlen(actual_seqlen):
return cu_seqlen
class SelfFusedAttnFwdPrimitive(BasePrimitive):
class FusedAttnFwdPrimitive(BasePrimitive):
"""
Self Fused Attention Forward Primitive
Fused Attention Forward Primitive
"""
name = "te_self_fused_attn_forward"
name = "te_fused_attn_forward"
multiple_results = True
impl_static_args = (4, 5, 6, 7, 8)
impl_static_args = (7, 8, 9, 10, 11, 12)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(qkv_aval, bias_aval, seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training):
def abstract(q_aval, k_aval, v_aval, bias_aval, q_seqlen_or_cu_seqlen_aval,
kv_seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, attn_mask_type,
qkv_layout, scaling_factor, dropout_probability, is_training):
"""
Self fused attention fwd inner primitive abstract
Fused attention fwd abstract
"""
# outer_primitve is squeezed_mask, inner_primitive is cu_seqlen
del seqlen_or_cu_seqlen_aval
qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
*input_batch_shape, max_seqlen, nqkv, attn_heads, head_dim = qkv_aval.shape
assert nqkv == 3
assert qkv_aval.dtype == bias_aval.dtype
q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
assert q_dtype == k_dtype == v_dtype == bias_dtype
assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
output_shape = (*input_batch_shape, max_seqlen, attn_heads, head_dim)
out_aval = qkv_aval.update(shape=output_shape, dtype=qkv_dtype)
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = \
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)
output_shape = (*batch_shape, q_max_seqlen, attn_heads, head_dim)
out_aval = q_aval.update(shape=output_shape, dtype=q_dtype)
# backend determines the softmax buffer shape/dtype
backend = FusedAttnHelper(qkv_dtype, qkv_dtype, NVTE_QKV_Layout.NVTE_BS3HD, attn_bias_type,
attn_mask_type, dropout_probability, attn_heads, attn_heads,
max_seqlen, max_seqlen, head_dim).get_fused_attn_backend()
backend = FusedAttnHelper(q_dtype, k_dtype, qkv_layout, attn_bias_type, attn_mask_type,
dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen,
kv_max_seqlen, head_dim).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_shape = (*input_batch_shape, attn_heads, max_seqlen, max_seqlen)
softmax_dtype = qkv_dtype
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen)
softmax_dtype = q_dtype
elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
softmax_shape = (*input_batch_shape, attn_heads, max_seqlen, 1)
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
else:
raise ValueError(f'Unsupported {backend=}')
softmax_aux_aval = qkv_aval.update(shape=softmax_shape, dtype=softmax_dtype)
softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype)
# JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with
# 32-bit unsigned int to get the buffer size we need in the C++ kernel
......@@ -1990,12 +2011,12 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
# do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
# prepare for the active fused-attn backend
input_batch = reduce(operator.mul, input_batch_shape)
wkspace_info = transformer_engine_jax.get_self_fused_attn_fwd_workspace_sizes(
input_batch, bias_batch, max_seqlen, attn_heads, bias_heads, head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
wkspace_aval = qkv_aval.update(shape=wkspace_info[0],
input_batch = reduce(operator.mul, batch_shape)
wkspace_info = transformer_engine_jax.get_fused_attn_fwd_workspace_sizes(
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups,
bias_heads, head_dim, scaling_factor, dropout_probability, attn_bias_type,
attn_mask_type, qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), is_training)
wkspace_aval = q_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval
......@@ -2003,19 +2024,19 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
@staticmethod
def outer_abstract(*args, **kwargs):
"""
Self fused attention fwd outer primitive abstract
Fused attention fwd outer primitive abstract
"""
out_aval, softmax_aux_aval, rng_state_aval, _ = \
SelfFusedAttnFwdPrimitive.abstract(*args, **kwargs)
FusedAttnFwdPrimitive.abstract(*args, **kwargs)
return out_aval, softmax_aux_aval, rng_state_aval
@staticmethod
def lowering(ctx, qkv, bias, cu_seqlen, seed, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
def lowering(ctx, q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type,
attn_mask_type, qkv_layout, scaling_factor, dropout_probability, is_training):
"""
Self fused attention fwd lowering rules
Fused attention fwd lowering rules
"""
operands = [qkv, bias, cu_seqlen, seed]
operands = [q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
......@@ -2023,9 +2044,12 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
qkv_aval, bias_aval, *_ = ctx.avals_in
*input_batch_shape, max_seqlen, _, attn_heads, head_dim = qkv_aval.shape
input_batch = reduce(operator.mul, input_batch_shape)
q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = \
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)
input_batch = reduce(operator.mul, batch_shape)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
......@@ -2036,137 +2060,137 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch, bias_batch, max_seqlen, max_seqlen,
attn_heads, attn_heads, bias_heads, head_dim, wkspace_aval.size,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(qkv_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups,
bias_heads, head_dim, wkspace_aval.size, scaling_factor, dropout_probability,
attn_bias_type, attn_mask_type, qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype), is_training)
out = custom_caller(SelfFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
return out
@staticmethod
def impl(qkv, bias, seqlen, seed, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
assert SelfFusedAttnFwdPrimitive.inner_primitive is not None
def impl(q, k, v, bias, q_seqlen, kv_seqlen, seed, attn_bias_type, attn_mask_type, qkv_layout,
scaling_factor, dropout_probability, is_training):
assert FusedAttnFwdPrimitive.inner_primitive is not None
cu_seqlen = generate_cu_seqlen(seqlen)
q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
output, softmax_aux, rng_state, _ = SelfFusedAttnFwdPrimitive.inner_primitive.bind(
qkv,
output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind(
q,
k,
v,
bias,
cu_seqlen,
q_cu_seqlen,
kv_cu_seqlen,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output, softmax_aux, rng_state
@staticmethod
def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, qkv_layout,
scaling_factor, dropout_probability, is_training):
_check_valid_batch_dims(batch_dims)
assert SelfFusedAttnFwdPrimitive.outer_primitive is not None
qkv_bdim, _, _, seed_bdim = batch_dims
assert FusedAttnFwdPrimitive.outer_primitive is not None
q_bdim, *_, seed_bdim = batch_dims
out_bdims = qkv_bdim, qkv_bdim, seed_bdim
return SelfFusedAttnFwdPrimitive.outer_primitive.bind(
*batched_args,
out_bdims = q_bdim, q_bdim, seed_bdim
return FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training), out_bdims
@staticmethod
def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
def infer_sharding_from_operands(attn_bias_type, attn_mask_type, qkv_layout, scaling_factor,
dropout_probability, is_training, mesh, arg_infos,
result_infos):
del attn_bias_type, attn_mask_type, scaling_factor
del dropout_probability, is_training, result_infos
x_spec = get_padded_spec(arg_infos[0]) # (...batch, seqlen, 3, head, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-3], *x_spec[-2:]))
q_spec = get_padded_spec(arg_infos[0])
k_spec = get_padded_spec(arg_infos[1])
match qkv_layout:
case NVTE_QKV_Layout.NVTE_BS3HD:
# q_spec = (...batch, q_seqlen, head, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:]))
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None))
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
# q_spec = (...batch, q_seqlen, head, hidden)
# k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-4]))
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
# q_spec = (...batch, q_seqlen, head, hidden)
# k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[:-4], x_spec[-2], x_spec[-4], None))
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3]))
case _:
raise ValueError(f"Unsupported {qkv_layout=}")
rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
return (out_sharding, softmax_aux_sharding, rng_state_sharding)
@staticmethod
def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
mesh, arg_infos, result_infos):
del result_infos
x_spec = get_padded_spec(arg_infos[0]) # (...batch, seqlen, 3, head, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-3], *x_spec[-2:]))
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[:-4], x_spec[-2], x_spec[-4], None))
rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [rng_state_sharding])
def partition(attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability,
is_training, mesh, arg_infos, result_infos):
out_sharding = result_infos[0].sharding
softmax_aux_sharding = result_infos[1].sharding
rng_state_sharding = seed_sharding = NamedSharding(mesh,
PartitionSpec(get_all_mesh_axes(), None))
arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding])
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
impl = partial(SelfFusedAttnFwdPrimitive.impl,
impl = partial(FusedAttnFwdPrimitive.impl,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return mesh, impl, out_shardings, arg_shardings
register_primitive(SelfFusedAttnFwdPrimitive)
def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray | None, seqlen: jnp.ndarray,
seed: jnp.ndarray | None, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool):
"""
Wrapper for TE self fused attention fwd
Return BMM1 -> (PreScaleBias) -> Scale -> (PostScaleBias) -> Softmax -> (Dropout) -> BMM2
"""
checker = _FusedAttnRNGStateChecker()
seed = checker.check_seed(seed, dropout_probability, is_training)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=qkv.dtype)
return SelfFusedAttnFwdPrimitive.outer_primitive.bind(qkv,
bias,
seqlen,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
register_primitive(FusedAttnFwdPrimitive)
class SelfFusedAttnBwdPrimitive(BasePrimitive):
class FusedAttnBwdPrimitive(BasePrimitive):
"""
Self Fused Attention Backward Primitive
Fused Attention Backward Primitive
"""
name = "te_self_fused_attn_backward"
name = "te_fused_attn_backward"
multiple_results = True
impl_static_args = (7, 8, 9, 10, 11)
impl_static_args = (10, 11, 12, 13, 14, 15)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(qkv_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval, doutput_aval,
seqlen_or_cu_seqlen_aval, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
def abstract(q_aval, k_aval, v_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval,
doutput_aval, q_cu_seqlen_aval, kv_cu_seqlen_aval, *, attn_bias_type,
attn_mask_type, qkv_layout, scaling_factor, dropout_probability, is_training):
"""
Self fused attention bwd abstract
Fused attention bwd abstract
"""
del softmax_aux_aval, rng_state_aval, seqlen_or_cu_seqlen_aval
del softmax_aux_aval, rng_state_aval, output_aval
assert qkv_aval.dtype == bias_aval.dtype == output_aval.dtype == doutput_aval.dtype
*input_batch_shape, max_seqlen, nqkv, attn_heads, head_dim = qkv_aval.shape
assert nqkv == 3
qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype)
assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype
assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = \
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
......@@ -2174,46 +2198,55 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
input_batch = reduce(operator.mul, input_batch_shape)
input_batch = reduce(operator.mul, batch_shape)
wkspace_shape, wkspace_dtype = \
transformer_engine_jax.get_self_fused_attn_bwd_workspace_sizes(
input_batch, bias_batch, max_seqlen, attn_heads, bias_heads, head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(qkv_aval.dtype), is_training
)
transformer_engine_jax.get_fused_attn_bwd_workspace_sizes(
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups,
bias_heads, head_dim, scaling_factor, dropout_probability, attn_bias_type,
attn_mask_type, qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), is_training)
dqkv_aval = qkv_aval.update(shape=qkv_aval.shape, dtype=qkv_dtype)
dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
dk_aval = k_aval.update(shape=k_aval.shape, dtype=k_dtype)
dv_aval = v_aval.update(shape=v_aval.shape, dtype=v_dtype)
dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
wkspace_aval = qkv_aval.update(shape=wkspace_shape,
wkspace_aval = q_aval.update(shape=wkspace_shape,
dtype=te_dtype_to_jax_dtype(wkspace_dtype))
return dqkv_aval, dbias_aval, wkspace_aval
return dq_aval, dk_aval, dv_aval, dbias_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
Self fused attention bwd outer primitive abstract
Fused attention bwd outer primitive abstract
"""
dqkv_aval, dbias_aval, _ = SelfFusedAttnBwdPrimitive.abstract(*args, **kwargs)
return dqkv_aval, dbias_aval
dq_aval, dk_aval, dv_aval, dbias_aval, _ = \
FusedAttnBwdPrimitive.abstract(*args, **kwargs)
return dq_aval, dk_aval, dv_aval, dbias_aval
@staticmethod
def lowering(ctx, qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen, *,
attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training):
def lowering(ctx, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen,
kv_cu_seqlen, *, attn_bias_type, attn_mask_type, qkv_layout, scaling_factor,
dropout_probability, is_training):
"""
Self fused attention bwd lowering rules
Fused attention bwd lowering rules
"""
operands = [qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen]
operands = [
q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen
]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
qkv_aval, bias_aval, *_ = ctx.avals_in
*input_batch_shape, max_seqlen, _, attn_heads, head_dim = qkv_aval.shape
input_batch = reduce(operator.mul, input_batch_shape)
q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = \
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)
input_batch = reduce(operator.mul, batch_shape)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
......@@ -2224,103 +2257,152 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch, bias_batch, max_seqlen, max_seqlen,
attn_heads, attn_heads, bias_heads, head_dim, wkspace_aval.size,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(qkv_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups,
bias_heads, head_dim, wkspace_aval.size, scaling_factor, dropout_probability,
attn_bias_type, attn_mask_type, qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype), is_training)
out = custom_caller(SelfFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
return out
@staticmethod
def impl(qkv, bias, softmax_aux, rng_state, output, doutput, seqlen, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training):
assert SelfFusedAttnBwdPrimitive.inner_primitive is not None
def impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen,
attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability,
is_training):
assert FusedAttnBwdPrimitive.inner_primitive is not None
cu_seqlen = generate_cu_seqlen(seqlen)
q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
dqkv, dbias, _ = SelfFusedAttnBwdPrimitive.inner_primitive.bind(
qkv,
dq, dk, dv, dbias, _ = FusedAttnBwdPrimitive.inner_primitive.bind(
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
cu_seqlen,
q_cu_seqlen,
kv_cu_seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return dqkv, dbias
return dq, dk, dv, dbias
@staticmethod
def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, qkv_layout,
scaling_factor, dropout_probability, is_training):
_check_valid_batch_dims(batch_dims)
assert SelfFusedAttnBwdPrimitive.outer_primitive is not None
qkv_bdim, *_ = batch_dims
assert FusedAttnBwdPrimitive.outer_primitive is not None
q_bdim, k_bdim, v_bdim, *_ = batch_dims
out_bdims = qkv_bdim, qkv_bdim
return SelfFusedAttnBwdPrimitive.outer_primitive.bind(
*batched_args,
out_bdims = q_bdim, k_bdim, v_bdim, q_bdim
return FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training), out_bdims
@staticmethod
def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
def infer_sharding_from_operands(attn_bias_type, attn_mask_type, qkv_layout, scaling_factor,
dropout_probability, is_training, mesh, arg_infos,
result_infos):
del attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
del is_training, result_infos
x_spec = get_padded_spec(arg_infos[0])
bias_spec = get_padded_spec(arg_infos[1])
dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
del attn_bias_type, attn_mask_type, qkv_layout, scaling_factor
del dropout_probability, is_training, result_infos
q_spec = get_padded_spec(arg_infos[0])
k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2])
bias_spec = get_padded_spec(arg_infos[3])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
return (dx_sharding, dbias_sharding)
return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
@staticmethod
def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
mesh, arg_infos, result_infos):
def partition(attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability,
is_training, mesh, arg_infos, result_infos):
del result_infos
x_spec = get_padded_spec(arg_infos[0])
bias_spec = get_padded_spec(arg_infos[1])
dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
q_spec = get_padded_spec(arg_infos[0])
k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2])
bias_spec = get_padded_spec(arg_infos[3])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (dx_sharding, dbias_sharding)
out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
def sharded_impl(qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen):
local_dx, local_dbias = SelfFusedAttnBwdPrimitive.impl(
qkv,
def sharded_impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen,
kv_cu_seqlen):
local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl(
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
cu_seqlen,
q_cu_seqlen,
kv_cu_seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
global_dbias = local_dbias
if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
return local_dx, global_dbias
return local_dq, local_dk, local_dv, global_dbias
return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(SelfFusedAttnBwdPrimitive)
register_primitive(FusedAttnBwdPrimitive)
def fused_attn_fwd_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray, seqlen: jnp.ndarray,
seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool):
"""
Wrapper for TE fused attention fwd with qkvpacked inputs
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
"""
checker = _FusedAttnRNGStateChecker()
seed = checker.check_seed(seed, dropout_probability, is_training)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=qkv.dtype)
_not_used = jnp.zeros(0, qkv.dtype)
return FusedAttnFwdPrimitive.outer_primitive.bind(qkv,
_not_used,
_not_used,
bias,
seqlen,
seqlen,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=NVTE_QKV_Layout.NVTE_BS3HD,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
def self_fused_attn_bwd(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: jnp.ndarray,
def fused_attn_bwd_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: jnp.ndarray,
rng_state: jnp.ndarray, output: jnp.ndarray, doutput: jnp.ndarray,
seqlen: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
......@@ -2332,672 +2414,88 @@ def self_fused_attn_bwd(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: jnp.nd
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=qkv.dtype)
return SelfFusedAttnBwdPrimitive.outer_primitive.bind(qkv,
dummy_input = jnp.zeros(0, dtype=qkv.dtype)
dqkv, *_, dbias = FusedAttnBwdPrimitive.outer_primitive.bind(
qkv,
dummy_input,
dummy_input,
bias,
softmax_aux,
rng_state,
output,
doutput,
seqlen,
seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=NVTE_QKV_Layout.NVTE_BS3HD,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return dqkv, dbias
class CrossFusedAttnFwdPrimitive(BasePrimitive):
def fused_attn_fwd_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray, seed: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
scaling_factor: float, dropout_probability: float, is_training: bool):
"""
Cross Fused Attention Forward Primitive
Wrapper for TE fused attention fwd with kvpacked inputs
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
"""
name = "te_cross_fused_attn_forward"
multiple_results = True
impl_static_args = (6, 7, 8, 9, 10)
inner_primitive = None
outer_primitive = None
checker = _FusedAttnRNGStateChecker()
seed = checker.check_seed(seed, dropout_probability, is_training)
@staticmethod
def abstract(q_aval, kv_aval, bias_aval, q_seqlen_or_cu_seqlen_aval,
kv_seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=q.dtype)
return FusedAttnFwdPrimitive.outer_primitive.bind(q,
kv,
jnp.zeros(0, q.dtype),
bias,
q_seqlen,
kv_seqlen,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
def fused_attn_bwd_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray,
doutput: jnp.ndarray, q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
scaling_factor: float, dropout_probability: float, is_training: bool):
"""
Cross fused attention fwd abstract
Wrapper for TE fused attention bwd with kvpacked inputs
Return the gradients of fused attention with packed kv input
"""
q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype)
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
assert q_dtype == kv_dtype == bias_dtype
assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = kv_aval.shape
assert q_batch_shape == kv_batch_shape
assert q_head_dim == kv_head_dim
assert nkv == 2
out_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
# backend determines the softmax buffer shape/dtype
backend = FusedAttnHelper(q_dtype, kv_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
attn_bias_type, attn_mask_type, dropout_probability, attn_heads,
num_gqa_groups, q_max_seqlen, kv_max_seqlen,
q_head_dim).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_shape = (*q_batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen)
softmax_dtype = q_dtype
elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
softmax_shape = (*q_batch_shape, attn_heads, q_max_seqlen, 1)
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
else:
raise ValueError(f'Unsupported {backend=}')
softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype)
# JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with
# 32-bit unsigned int to get the buffer size we need in the C++ kernel
checker = _FusedAttnRNGStateChecker()
seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
assert seed_dtype == checker.rng_state_dtype
rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
# do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
# prepare for the active fused-attn backend
input_batch = reduce(operator.mul, q_batch_shape)
wkspace_info = transformer_engine_jax.get_cross_fused_attn_fwd_workspace_sizes(
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, q_head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training)
wkspace_aval = q_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
Cross fused attention fwd outer primitive abstract
"""
out_aval, softmax_aux_aval, rng_state_aval, _ = \
CrossFusedAttnFwdPrimitive.abstract(*args, **kwargs)
return out_aval, softmax_aux_aval, rng_state_aval
@staticmethod
def lowering(ctx, q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training):
"""
Cross fused attention fwd lowering rules
"""
operands = [q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
q_aval, kv_aval, bias_aval, *_ = ctx.avals_in
*input_batch_shape, q_max_seqlen, attn_heads, head_dim = q_aval.shape
*_, kv_max_seqlen, _, num_gqa_groups, _ = kv_aval.shape
input_batch = reduce(operator.mul, input_batch_shape)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, head_dim,
wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
out = custom_caller(CrossFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
return out
@staticmethod
def impl(q, kv, bias, q_seqlen, kv_seqlen, seed, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
assert CrossFusedAttnFwdPrimitive.inner_primitive is not None
q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
output, softmax_aux, rng_state, _ = CrossFusedAttnFwdPrimitive.inner_primitive.bind(
q,
kv,
bias,
q_cu_seqlen,
kv_cu_seqlen,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output, softmax_aux, rng_state
@staticmethod
def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
_check_valid_batch_dims(batch_dims)
assert CrossFusedAttnFwdPrimitive.outer_primitive is not None
q_bdim, *_, seed_bdim = batch_dims
out_bdims = q_bdim, q_bdim, seed_bdim
return CrossFusedAttnFwdPrimitive.outer_primitive.bind(
*batched_args,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training), out_bdims
@staticmethod
def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training, mesh, arg_infos,
result_infos):
del attn_bias_type, attn_mask_type, scaling_factor
del dropout_probability, is_training, result_infos
q_spec = get_padded_spec(arg_infos[0]) # (...batch, q_seqlen, head, hidden)
kv_spec = get_padded_spec(arg_infos[1]) # (...batch, kv_seqlen, 2, head, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], kv_spec[-4]))
rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
return (out_sharding, softmax_aux_sharding, rng_state_sharding)
@staticmethod
def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
mesh, arg_infos, result_infos):
del result_infos
q_spec = get_padded_spec(arg_infos[0]) # (...batch, q_seqlen, head, hidden)
kv_spec = get_padded_spec(arg_infos[1]) # (...batch, kv_seqlen, 2, head, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], kv_spec[-4]))
rng_state_sharding = seed_sharding = NamedSharding(mesh,
PartitionSpec(get_all_mesh_axes(), None))
arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding])
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
impl = partial(CrossFusedAttnFwdPrimitive.impl,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return mesh, impl, out_shardings, arg_shardings
register_primitive(CrossFusedAttnFwdPrimitive)
def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, q_seqlen: jnp.ndarray,
kv_seqlen: jnp.ndarray, seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool):
"""
Wrapper for TE cross fused attention fwd
Return BMM1 -> (PreScaleBias) -> Scale -> (PostScaleBias) -> Softmax -> (Dropout) -> BMM2
"""
checker = _FusedAttnRNGStateChecker()
seed = checker.check_seed(seed, dropout_probability, is_training)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=q.dtype)
return CrossFusedAttnFwdPrimitive.outer_primitive.bind(q,
kv,
bias,
q_seqlen,
kv_seqlen,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
class CrossFusedAttnBwdPrimitive(BasePrimitive):
"""
Cross Fused Attention Backward Primitive
"""
name = "te_cross_fused_attn_backward"
multiple_results = True
impl_static_args = (9, 10, 11, 12, 13)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(q_aval, kv_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval,
doutput_aval, q_cu_seqlen_aval, kv_cu_seqlen_aval, *, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training):
"""
Cross fused attention bwd abstract
"""
del softmax_aux_aval, rng_state_aval, output_aval
q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype)
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype)
assert q_dtype == kv_dtype == bias_dtype == doutput_dtype
assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = kv_aval.shape
assert q_batch_shape == kv_batch_shape
assert q_head_dim == kv_head_dim
assert nkv == 2
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
input_batch = reduce(operator.mul, q_batch_shape)
wkspace_shape, wkspace_dtype = \
transformer_engine_jax.get_cross_fused_attn_bwd_workspace_sizes(
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, q_head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training
)
dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
dkv_aval = kv_aval.update(shape=kv_aval.shape, dtype=kv_dtype)
dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
wkspace_aval = q_aval.update(shape=wkspace_shape,
dtype=te_dtype_to_jax_dtype(wkspace_dtype))
return dq_aval, dkv_aval, dbias_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
Cross fused attention bwd outer primitive abstract
"""
dq_aval, dkv_aval, dbias_aval, _ = \
CrossFusedAttnBwdPrimitive.abstract(*args, **kwargs)
return dq_aval, dkv_aval, dbias_aval
@staticmethod
def lowering(ctx, q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen,
kv_cu_seqlen, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
"""
Cross fused attention bwd lowering rules
"""
operands = [q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
q_aval, kv_aval, bias_aval, *_ = ctx.avals_in
*input_batch_shape, q_max_seqlen, attn_heads, head_dim = q_aval.shape
*_, kv_max_seqlen, _, num_gqa_groups, _ = kv_aval.shape
input_batch = reduce(operator.mul, input_batch_shape)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, head_dim,
wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
out = custom_caller(CrossFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
return out
@staticmethod
def impl(q, kv, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen,
attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training):
assert CrossFusedAttnBwdPrimitive.inner_primitive is not None
q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
dq, dkv, dbias, _ = CrossFusedAttnBwdPrimitive.inner_primitive.bind(
q,
kv,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return dq, dkv, dbias
@staticmethod
def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
_check_valid_batch_dims(batch_dims)
assert CrossFusedAttnBwdPrimitive.outer_primitive is not None
q_bdim, kv_bdim, *_ = batch_dims
out_bdims = q_bdim, kv_bdim, q_bdim
return CrossFusedAttnBwdPrimitive.outer_primitive.bind(
*batched_args,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training), out_bdims
@staticmethod
def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training, mesh, arg_infos,
result_infos):
del attn_bias_type, attn_mask_type, scaling_factor
del dropout_probability, is_training, result_infos
q_spec = get_padded_spec(arg_infos[0])
kv_spec = get_padded_spec(arg_infos[1])
bias_spec = get_padded_spec(arg_infos[2])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dkv_sharding = NamedSharding(mesh, PartitionSpec(*kv_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
return (dq_sharding, dkv_sharding, dbias_sharding)
@staticmethod
def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
mesh, arg_infos, result_infos):
del result_infos
q_spec = get_padded_spec(arg_infos[0])
kv_spec = get_padded_spec(arg_infos[1])
bias_spec = get_padded_spec(arg_infos[2])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dkv_sharding = NamedSharding(mesh, PartitionSpec(*kv_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (dq_sharding, dkv_sharding, dbias_sharding)
def sharded_impl(q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen,
kv_cu_seqlen):
local_dq, local_dkv, local_dbias = CrossFusedAttnBwdPrimitive.impl(
q,
kv,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
global_dbias = local_dbias
if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
return local_dq, local_dkv, global_dbias
return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(CrossFusedAttnBwdPrimitive)
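# Conceptual sketch of the dbias reduction performed in the sharded impl above: with
# the batch sharded across data-parallel/FSDP devices and the bias replicated, each
# shard only produces its local contribution to dbias, so the full gradient is the sum
# over those shards. The axis name 'dp' below is a hypothetical mesh axis for illustration.
def _sketch_dbias_allreduce(local_dbias):
    from jax import lax
    return lax.psum(local_dbias, axis_name='dp')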
def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray,
doutput: jnp.ndarray, q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
scaling_factor: float, dropout_probability: float, is_training: bool):
"""
Wrapper for TE cross fused attention bwd
Return the gradients of cross fused attention with packed kv input
"""
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=q.dtype)
return CrossFusedAttnBwdPrimitive.outer_primitive.bind(q,
kv,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_seqlen,
kv_seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
class FusedAttnFwdPrimitive(BasePrimitive):
"""
Fused Attention Forward Primitive
Query, key, and value are separate tensors
"""
name = "te_fused_attn_forward"
multiple_results = True
impl_static_args = (7, 8, 9, 10, 11)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(q_aval, k_aval, v_aval, bias_aval, q_seqlen_or_cu_seqlen_aval,
kv_seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
"""
Fused attention fwd abstract
"""
q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
assert q_dtype == k_dtype == v_dtype == bias_dtype
assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape
assert q_batch_shape == kv_batch_shape
assert q_head_dim == kv_head_dim
assert k_aval.shape == v_aval.shape
out_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
# backend determines the softmax buffer shape/dtype
backend = FusedAttnHelper(q_dtype, k_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD, attn_bias_type,
attn_mask_type, dropout_probability, attn_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, q_head_dim).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_shape = (*q_batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen)
softmax_dtype = q_dtype
elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
softmax_shape = (*q_batch_shape, attn_heads, q_max_seqlen, 1)
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
else:
raise ValueError(f'Unsupported {backend=}')
softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype)
# JAX does not enable 64-bit integers by default, so we have XLA allocate 8x the memory
# as 32-bit unsigned ints to obtain the buffer size the C++ kernel needs
checker = _FusedAttnRNGStateChecker()
seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
assert seed_dtype == checker.rng_state_dtype
rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
# do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
# prepare for the active fused-attn backend
input_batch = reduce(operator.mul, q_batch_shape)
wkspace_info = transformer_engine_jax.get_fused_attn_fwd_workspace_sizes(
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, q_head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training)
wkspace_aval = q_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
Fused attention fwd outer primitive abstract
"""
out_aval, softmax_aux_aval, rng_state_aval, _ = \
FusedAttnFwdPrimitive.abstract(*args, **kwargs)
return out_aval, softmax_aux_aval, rng_state_aval
@staticmethod
def lowering(ctx, q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training):
"""
Fused attention fwd lowering rules
"""
operands = [q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
*batch_shape, q_max_seqlen, attn_heads, head_dim = q_aval.shape
*_, kv_max_seqlen, num_gqa_groups, _ = k_aval.shape
assert k_aval.shape == v_aval.shape
input_batch = reduce(operator.mul, batch_shape)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, head_dim,
wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
return out
@staticmethod
def impl(q, k, v, bias, q_seqlen, kv_seqlen, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
assert FusedAttnFwdPrimitive.inner_primitive is not None
q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind(
    q,
    k,
    v,
    bias,
    q_cu_seqlen,
    kv_cu_seqlen,
    seed,
    attn_bias_type=attn_bias_type,
    attn_mask_type=attn_mask_type,
    scaling_factor=scaling_factor,
    dropout_probability=dropout_probability,
    is_training=is_training)
return output, softmax_aux, rng_state
@staticmethod
def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
_check_valid_batch_dims(batch_dims)
assert FusedAttnFwdPrimitive.outer_primitive is not None
q_bdim, *_, seed_bdim = batch_dims
out_bdims = q_bdim, q_bdim, seed_bdim
return FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training), out_bdims
@staticmethod
def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training, mesh, arg_infos,
result_infos):
del attn_bias_type, attn_mask_type, scaling_factor
del dropout_probability, is_training, result_infos
q_spec = get_padded_spec(arg_infos[0]) # (...batch, q_seqlen, head, hidden)
k_spec = get_padded_spec(arg_infos[1]) # (...batch, kv_seqlen, head, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3]))
rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
return (out_sharding, softmax_aux_sharding, rng_state_sharding)
@staticmethod
def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
mesh, arg_infos, result_infos):
del result_infos
q_spec = get_padded_spec(arg_infos[0]) # (...batch, q_seqlen, head, hidden)
k_spec = get_padded_spec(arg_infos[1]) # (...batch, kv_seqlen, head, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3]))
rng_state_sharding = seed_sharding = NamedSharding(mesh,
PartitionSpec(get_all_mesh_axes(), None))
arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding])
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
impl = partial(FusedAttnFwdPrimitive.impl,
               attn_bias_type=attn_bias_type,
               attn_mask_type=attn_mask_type,
               scaling_factor=scaling_factor,
               dropout_probability=dropout_probability,
               is_training=is_training)
return mesh, impl, out_shardings, arg_shardings
register_primitive(FusedAttnFwdPrimitive)
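# Conceptual sketch of the seqlen -> cu_seqlen transform that the impl above performs
# via generate_cu_seqlen: per-sequence lengths become prefix-sum offsets into the packed
# buffer, e.g. [3, 5, 4] -> [0, 3, 8, 12]. The actual helper may differ in details such
# as padding handling.
def _sketch_cu_seqlen(seqlen):
    return jnp.concatenate([jnp.zeros((1,), dtype=seqlen.dtype), jnp.cumsum(seqlen)])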
def fused_attn_fwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray,
......@@ -3006,7 +2504,7 @@ def fused_attn_fwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.nda
scaling_factor: float, dropout_probability: float, is_training: bool):
"""
Wrapper for TE fused attention fwd, where query, key, and value are separate tensors
Return BMM1 -> (PreScaleBias) -> Scale -> (PostScaleBias) -> Softmax -> (Dropout) -> BMM2
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
"""
checker = _FusedAttnRNGStateChecker()
seed = checker.check_seed(seed, dropout_probability, is_training)
......@@ -3015,7 +2513,8 @@ def fused_attn_fwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.nda
assert bias is None
bias = jnp.zeros(0, dtype=q.dtype)
return FusedAttnFwdPrimitive.outer_primitive.bind(q,
return FusedAttnFwdPrimitive.outer_primitive.bind(
q,
k,
v,
bias,
......@@ -3024,221 +2523,12 @@ def fused_attn_fwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.nda
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
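# Shape-only sketch (hypothetical sizes) of the separate-tensor layout this wrapper
# consumes: with grouped-query attention the query carries attn_heads while key and
# value carry num_gqa_groups heads sharing the same head_dim.
def _sketch_separate_qkv_shapes():
    b, sq, skv, d = 2, 128, 256, 64
    attn_heads, num_gqa_groups = 16, 4    # every 4 query heads share one KV head
    q = jnp.zeros((b, sq, attn_heads, d), jnp.bfloat16)
    k = jnp.zeros((b, skv, num_gqa_groups, d), jnp.bfloat16)
    v = jnp.zeros((b, skv, num_gqa_groups, d), jnp.bfloat16)
    q_seqlen = jnp.full((b,), sq, jnp.int32)
    kv_seqlen = jnp.full((b,), skv, jnp.int32)
    return q, k, v, q_seqlen, kv_seqlen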
class FusedAttnBwdPrimitive(BasePrimitive):
"""
Fused Attention Backward Primitive
"""
name = "te_fused_attn_backward"
multiple_results = True
impl_static_args = (10, 11, 12, 13, 14)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(q_aval, k_aval, v_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval,
doutput_aval, q_cu_seqlen_aval, kv_cu_seqlen_aval, *, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training):
"""
Fused attention bwd abstract
"""
del softmax_aux_aval, rng_state_aval, output_aval
q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype)
assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype
assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape
assert q_batch_shape == kv_batch_shape
assert q_head_dim == kv_head_dim
assert k_aval.shape == v_aval.shape
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
input_batch = reduce(operator.mul, q_batch_shape)
wkspace_shape, wkspace_dtype = \
transformer_engine_jax.get_fused_attn_bwd_workspace_sizes(
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, q_head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training
)
dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
dk_aval = k_aval.update(shape=k_aval.shape, dtype=k_dtype)
dv_aval = v_aval.update(shape=v_aval.shape, dtype=v_dtype)
dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
wkspace_aval = q_aval.update(shape=wkspace_shape,
dtype=te_dtype_to_jax_dtype(wkspace_dtype))
return dq_aval, dk_aval, dv_aval, dbias_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
Fused attention bwd outer primitive abstract
"""
dq_aval, dk_aval, dv_aval, dbias_aval, _ = \
FusedAttnBwdPrimitive.abstract(*args, **kwargs)
return dq_aval, dk_aval, dv_aval, dbias_aval
@staticmethod
def lowering(ctx, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen,
kv_cu_seqlen, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
"""
Fused attention bwd lowering rules
"""
operands = [
q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen
]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
*batch_shape, q_max_seqlen, attn_heads, head_dim = q_aval.shape
*_, kv_max_seqlen, num_gqa_groups, _ = k_aval.shape
assert k_aval.shape == v_aval.shape
input_batch = reduce(operator.mul, batch_shape)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, head_dim,
wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
return out
@staticmethod
def impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen,
attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training):
assert FusedAttnBwdPrimitive.inner_primitive is not None
q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
dq, dk, dv, dbias, _ = FusedAttnBwdPrimitive.inner_primitive.bind(
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return dq, dk, dv, dbias
@staticmethod
def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
_check_valid_batch_dims(batch_dims)
assert FusedAttnBwdPrimitive.outer_primitive is not None
q_bdim, k_bdim, v_bdim, *_ = batch_dims
out_bdims = q_bdim, k_bdim, v_bdim, q_bdim
return FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training), out_bdims
@staticmethod
def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training, mesh, arg_infos,
result_infos):
del attn_bias_type, attn_mask_type, scaling_factor
del dropout_probability, is_training, result_infos
q_spec = get_padded_spec(arg_infos[0])
k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2])
bias_spec = get_padded_spec(arg_infos[3])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
@staticmethod
def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
mesh, arg_infos, result_infos):
del result_infos
q_spec = get_padded_spec(arg_infos[0])
k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2])
bias_spec = get_padded_spec(arg_infos[3])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
def sharded_impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen,
kv_cu_seqlen):
local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl(
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
global_dbias = local_dbias
if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
return local_dq, local_dk, local_dv, global_dbias
return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(FusedAttnBwdPrimitive)
def fused_attn_bwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray,
softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray,
doutput: jnp.ndarray, q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray,
......@@ -3251,8 +2541,8 @@ def fused_attn_bwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.nda
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=q.dtype)
return FusedAttnBwdPrimitive.outer_primitive.bind(q,
return FusedAttnBwdPrimitive.outer_primitive.bind(
q,
k,
v,
bias,
......@@ -3264,6 +2554,7 @@ def fused_attn_bwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.nda
kv_seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
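# Minimal sketch of how these fwd/bwd wrappers can be tied together with jax.custom_vjp
# (assumptions: no bias, seqlens and seed closed over for brevity); the public API in
# fused_attn.py carries the full argument set and the bias/mask handling.
def _sketch_make_fused_attn(q_seqlen, kv_seqlen, seed, attn_mask_type, scaling_factor,
                            dropout_probability, is_training):
    import jax

    @jax.custom_vjp
    def attn(q, k, v):
        out, _, _ = _fwd(q, k, v)
        return out

    def _fwd(q, k, v):
        return fused_attn_fwd(q, k, v, None, q_seqlen, kv_seqlen, seed,
                              attn_bias_type=NVTE_Bias_Type.NVTE_NO_BIAS,
                              attn_mask_type=attn_mask_type,
                              scaling_factor=scaling_factor,
                              dropout_probability=dropout_probability,
                              is_training=is_training)

    def attn_fwd_rule(q, k, v):
        out, softmax_aux, rng_state = _fwd(q, k, v)
        return out, (q, k, v, softmax_aux, rng_state, out)

    def attn_bwd_rule(residuals, dout):
        q, k, v, softmax_aux, rng_state, out = residuals
        dq, dk, dv, _ = fused_attn_bwd(q, k, v, None, softmax_aux, rng_state, out, dout,
                                       q_seqlen, kv_seqlen,
                                       attn_bias_type=NVTE_Bias_Type.NVTE_NO_BIAS,
                                       attn_mask_type=attn_mask_type,
                                       scaling_factor=scaling_factor,
                                       dropout_probability=dropout_probability,
                                       is_training=is_training)
        return dq, dk, dv

    attn.defvjp(attn_fwd_rule, attn_bwd_rule)
    return attn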
......
......@@ -49,10 +49,6 @@ pybind11::dict Registrations() {
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForward);
dict["te_scaled_upper_triang_masked_softmax_backward"] =
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward);
dict["te_self_fused_attn_forward"] = EncapsulateFunction(SelfFusedAttnForward);
dict["te_self_fused_attn_backward"] = EncapsulateFunction(SelfFusedAttnBackward);
dict["te_cross_fused_attn_forward"] = EncapsulateFunction(CrossFusedAttnForward);
dict["te_cross_fused_attn_backward"] = EncapsulateFunction(CrossFusedAttnBackward);
dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward);
dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward);
return dict;
......@@ -72,10 +68,6 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_dgelu_dbias_ct_workspace_sizes", &GetDGeluDBiasCastTransposeWorkspaceSizes);
m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes);
m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes);
m.def("get_self_fused_attn_fwd_workspace_sizes", &GetSelfFusedAttnForwardWorkspaceSizes);
m.def("get_self_fused_attn_bwd_workspace_sizes", &GetSelfFusedAttnBackwardWorkspaceSizes);
m.def("get_cross_fused_attn_fwd_workspace_sizes", &GetCrossFusedAttnForwardWorkspaceSizes);
m.def("get_cross_fused_attn_bwd_workspace_sizes", &GetCrossFusedAttnBackwardWorkspaceSizes);
m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes);
m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
......
......@@ -11,14 +11,12 @@
#include <cuda_runtime_api.h>
#include <cudnn.h>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>
#include "common/common.h"
#include "common/util/cuda_runtime.h"
#include "common/util/logging.h"
#include "transformer_engine/activation.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/fused_attn.h"
......@@ -96,13 +94,13 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t paddin
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
size_t wkspace_size, float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
DType dtype, DType wkspace_dtype, bool is_training) {
size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
bool is_training) {
return PackOpaque(CustomCallFusedAttnDescriptor{
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups,
bias_heads, head_dim, wkspace_size, scaling_factor, dropout_probability,
bias_type, mask_type, dtype, wkspace_dtype, is_training});
bias_heads, head_dim, wkspace_size, scaling_factor, dropout_probability, bias_type,
mask_type, qkv_layout, dtype, wkspace_dtype, is_training});
}
void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream,
......@@ -942,12 +940,12 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers,
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_num_heads, size_t kv_num_heads,
size_t q_attn_heads, size_t kv_attn_heads,
size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim) {
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_num_heads, kv_num_heads, q_max_seqlen, kv_max_seqlen,
mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen,
head_dim);
return backend;
}
......@@ -1029,244 +1027,31 @@ void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack,
}
}
pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t max_seqlen,
size_t attn_heads, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto qkv_shape = std::vector<size_t>{input_batch * max_seqlen, 3, attn_heads, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, max_seqlen, max_seqlen};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto o_tensor = TensorWrapper(
nullptr, std::vector<size_t>{input_batch, max_seqlen, attn_heads, head_dim}, dtype);
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, attn_heads, max_seqlen, max_seqlen, head_dim);
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
TensorWrapper query_workspace_tensor;
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
rng_state_tensor.data(), max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}
void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input buffers from XLA
void *qkv = buffers[0];
void *bias = buffers[1];
void *cu_seqlens = buffers[2];
void *seed = buffers[3];
// output buffers from XLA
void *output = buffers[4];
void *softmax_aux = buffers[5];
void *rng_state = buffers[6];
void *workspace = buffers[7];
// tensor sizes
auto input_batch = descriptor.input_batch;
auto bias_batch = descriptor.bias_batch;
auto max_seqlen = descriptor.q_max_seqlen;
auto attn_heads = descriptor.attn_heads;
auto bias_heads = descriptor.bias_heads;
auto head_dim = descriptor.head_dim;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
auto dtype = descriptor.dtype;
auto qkv_shape = std::vector<size_t>{input_batch * max_seqlen, 3, attn_heads, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, max_seqlen, max_seqlen};
// input tensors
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
auto cu_seqlens_tensor =
TensorWrapper(cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
// output tensors
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in FP16/BF16
auto o_tensor = TensorWrapper(
output, std::vector<size_t>{input_batch * max_seqlen, attn_heads, head_dim}, dtype);
// prep RNG state
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, attn_heads, max_seqlen, max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, max_seqlen, max_seqlen, backend, stream);
// auxiliary tensors (to be propagated to the backward pass later)
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend,
softmax_aux);
// cuDNN workspace
auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
auto wkspace_dtype = descriptor.wkspace_dtype;
auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
rng_state_tensor.data(), max_seqlen, descriptor.is_training,
descriptor.scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors);
}
pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t max_seqlen,
size_t attn_heads, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto qkv_shape = std::vector<size_t>{input_batch * max_seqlen, 3, attn_heads, head_dim};
auto output_shape = std::vector<size_t>{input_batch * max_seqlen, attn_heads, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, max_seqlen, max_seqlen};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
// F16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
TensorWrapper query_workspace_tensor;
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
cu_seqlens_tensor.data(), max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}
void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input buffers from XLA
void *qkv = buffers[0];
void *bias = buffers[1];
void *softmax_aux = buffers[2];
void *rng_state = buffers[3];
void *output = buffers[4];
void *doutput = buffers[5];
void *cu_seqlens = buffers[6];
// output buffers from XLA
void *dqkv = buffers[7];
void *dbias = buffers[8];
void *workspace = buffers[9];
// tensor sizes
auto input_batch = descriptor.input_batch;
auto bias_batch = descriptor.bias_batch;
auto max_seqlen = descriptor.q_max_seqlen;
auto attn_heads = descriptor.attn_heads;
auto bias_heads = descriptor.bias_heads;
auto head_dim = descriptor.head_dim;
auto scaling_factor = descriptor.scaling_factor;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
auto dtype = descriptor.dtype;
auto qkv_shape = std::vector<size_t>{input_batch * max_seqlen, 3, attn_heads, head_dim};
auto output_shape = std::vector<size_t>{input_batch * max_seqlen, attn_heads, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, max_seqlen, max_seqlen};
// input tensors
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
auto output_tensor = TensorWrapper(output, output_shape, dtype);
auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
// output tensors
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in FP16/BF16
auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
auto cu_seqlens_tensor =
TensorWrapper(cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
// auxiliary tensors (propagated from the forward pass)
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, attn_heads, max_seqlen, max_seqlen, head_dim);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
rng_state, bias);
// cuDNN workspace
auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
auto wkspace_dtype = descriptor.wkspace_dtype;
auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
cu_seqlens_tensor.data(), max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_input_tensors);
}
pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training) {
// For qkv_packed
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
// For kv_packed
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
// For separate q, k, v
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto v_shape = k_shape;
auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
// FP16/BF16 doesn't use this tensor
// F16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);
......@@ -1281,292 +1066,133 @@ pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
nvte_tensor_pack_create(&aux_output_tensors);
TensorWrapper query_workspace_tensor;
if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
assert(q_max_seqlen == kv_max_seqlen);
nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), dummy_rng_state_tensor.data(),
q_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, query_workspace_tensor.data(), nullptr);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
nvte_fused_attn_fwd_kvpacked(q_tensor.data(), kv_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}
void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input buffers from XLA
void *q = buffers[0];
void *kv = buffers[1];
void *bias = buffers[2];
void *q_cu_seqlens = buffers[3];
void *kv_cu_seqlens = buffers[4];
void *seed = buffers[5];
// output buffers from XLA
void *output = buffers[6];
void *softmax_aux = buffers[7];
void *rng_state = buffers[8];
void *workspace = buffers[9];
// tensor sizes
auto input_batch = descriptor.input_batch;
auto bias_batch = descriptor.bias_batch;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto attn_heads = descriptor.attn_heads;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto bias_heads = descriptor.bias_heads;
auto head_dim = descriptor.head_dim;
auto scaling_factor = descriptor.scaling_factor;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
// input tensors
auto dtype = descriptor.dtype;
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
// output tensors
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in FP16/BF16
auto o_tensor = TensorWrapper(output, q_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
// prep RNG state
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
// auxiliary tensors (to be propagated to the backward pass later)
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend,
softmax_aux);
// cuDNN workspace
auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size},
descriptor.wkspace_dtype);
nvte_fused_attn_fwd_kvpacked(q_tensor.data(), kv_tensor.data(), bias_tensor.data(),
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
descriptor.is_training, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors);
}
pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
// FP16/BF16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
TensorWrapper query_workspace_tensor;
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for FP16/BF16
s_tensor.data(), // not used for FP16/BF16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported QKVLayout.");
}
auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype());
}
void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input buffers from XLA
void *q = buffers[0];
void *kv = buffers[1];
void *bias = buffers[2];
void *softmax_aux = buffers[3];
void *rng_state = buffers[4];
void *output = buffers[5];
void *doutput = buffers[6];
void *q_cu_seqlens = buffers[7];
void *kv_cu_seqlens = buffers[8];
// output buffers from XLA
void *dq = buffers[9];
void *dkv = buffers[10];
void *dbias = buffers[11];
void *workspace = buffers[12];
// tensor sizes
auto input_batch = descriptor.input_batch;
auto bias_batch = descriptor.bias_batch;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto attn_heads = descriptor.attn_heads;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto bias_heads = descriptor.bias_heads;
auto head_dim = descriptor.head_dim;
auto scaling_factor = descriptor.scaling_factor;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype,
bool is_training) {
auto output_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
auto bias_shape = std::vector<size_t>{1, attn_heads, q_max_seqlen, kv_max_seqlen};
auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
// input tensors
auto dtype = descriptor.dtype;
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
auto output_tensor = TensorWrapper(output, output_shape, dtype);
auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
// F16 doesn't use s_tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
// output tensors
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in FP16/BF16
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype);
auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
// auxiliary tensors (propagated from the forward pass)
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
rng_state, bias);
// cuDNN workspace
auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
auto wkspace_dtype = descriptor.wkspace_dtype;
auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);
TensorWrapper query_workspace_tensor;
if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
assert(q_max_seqlen == kv_max_seqlen);
auto qkv_shape = std::vector<size_t>{batch_size * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
nvte_fused_attn_bwd_qkvpacked(
qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_shape =
std::vector<size_t>{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for FP16/BF16
s_tensor.data(), // not used for FP16/BF16
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_input_tensors);
}
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
query_workspace_tensor.data(), nullptr);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto k_shape = std::vector<size_t>{batch_size * kv_max_seqlen, num_gqa_groups, head_dim};
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto v_shape = k_shape;
auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
// F16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
TensorWrapper query_workspace_tensor;
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype);
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(),
dv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported QKVLayout.");
}
auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype());
}
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input buffers from XLA
void *q = buffers[0];
void *k = buffers[1];
void *v = buffers[2];
/* Input buffers from XLA */
/* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */
void *bias = buffers[3];
void *q_cu_seqlens = buffers[4];
void *kv_cu_seqlens = buffers[5];
void *seed = buffers[6];
// output buffers from XLA
/* Output buffers from XLA */
void *output = buffers[7];
void *softmax_aux = buffers[8];
void *rng_state = buffers[9];
void *workspace = buffers[10];
// tensor sizes
/* Descriptor */
auto input_batch = descriptor.input_batch;
auto bias_batch = descriptor.bias_batch;
auto q_max_seqlen = descriptor.q_max_seqlen;
......@@ -1579,29 +1205,26 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
auto qkv_layout = descriptor.qkv_layout;
auto dtype = descriptor.dtype;
/* Input tensors */
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
// input tensors
auto dtype = descriptor.dtype;
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype);
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
// output tensors
/* Output tensors */
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16
auto o_tensor = TensorWrapper(output, q_shape, dtype);
auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto o_tensor = TensorWrapper(output, o_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
// prep RNG state
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;
/* Prepare RNG state */
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
......@@ -1609,22 +1232,59 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
// auxiliary tensors (to be propagated to the backward pass later)
/* Auxiliary tensors (to be propagated to the backward pass later) */
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend,
softmax_aux);
// cuDNN workspace
/* cuDNN workspace */
auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size},
descriptor.wkspace_dtype);
/* Call the underlying NVTE API */
if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
auto qkv = buffers[0];
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), rng_state_tensor.data(), q_max_seqlen,
descriptor.is_training, descriptor.scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv = buffers[1];
auto kv_shape =
std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k = buffers[1];
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v = buffers[2];
auto v_shape = k_shape;
auto v_tensor = TensorWrapper(v, v_shape, dtype);
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
descriptor.is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
nvte_tensor_pack_destroy(&aux_output_tensors);
}
......@@ -1632,10 +1292,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
......@@ -1682,10 +1340,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input buffers from XLA
void *q = buffers[0];
void *k = buffers[1];
void *v = buffers[2];
/* Input buffers from XLA */
/* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */
void *bias = buffers[3];
void *softmax_aux = buffers[4];
void *rng_state = buffers[5];
......@@ -1694,14 +1350,12 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
void *q_cu_seqlens = buffers[8];
void *kv_cu_seqlens = buffers[9];
// output buffers from XLA
void *dq = buffers[10];
void *dk = buffers[11];
void *dv = buffers[12];
/* Output buffers from XLA */
/* Buffers[10-12] are dq, dk, dv, which are parsed later for different qkv_layout */
void *dbias = buffers[13];
void *workspace = buffers[14];
// tensor sizes
/* Descriptor */
auto input_batch = descriptor.input_batch;
auto bias_batch = descriptor.bias_batch;
auto q_max_seqlen = descriptor.q_max_seqlen;
......@@ -1714,36 +1368,26 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
auto qkv_layout = descriptor.qkv_layout;
auto dtype = descriptor.dtype;
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
/* Input tensors */
auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
// input tensors
auto dtype = descriptor.dtype;
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype);
auto output_tensor = TensorWrapper(output, output_shape, dtype);
auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
// output tensors
/* Output tensors */
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dk_tensor = TensorWrapper(dk, k_shape, dtype);
auto dv_tensor = TensorWrapper(dv, v_shape, dtype);
auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
// auxiliary tensors (propagated from the forward pass)
/* Auxiliary tensors (propagated from the forward pass) */
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
......@@ -1751,20 +1395,73 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
rng_state, bias);
// cuDNN workspace
/* cuDNN workspace */
auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
auto wkspace_dtype = descriptor.wkspace_dtype;
auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);
/* Call the underlying NVTE API */
if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
auto qkv = buffers[0];
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
auto dqkv = buffers[10];
auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
nvte_fused_attn_bwd_qkvpacked(
qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv = buffers[1];
auto kv_shape =
std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
auto dq = buffers[10];
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dkv = buffers[11];
auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype);
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k = buffers[1];
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v = buffers[2];
auto v_shape = k_shape;
auto v_tensor = TensorWrapper(v, v_shape, dtype);
auto dq = buffers[10];
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dk = buffers[11];
auto dk_tensor = TensorWrapper(dk, k_shape, dtype);
auto dv = buffers[12];
auto dv_tensor = TensorWrapper(dv, v_shape, dtype);
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type,
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(),
dv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
nvte_tensor_pack_destroy(&aux_input_tensors);
}
......
......@@ -118,17 +118,18 @@ struct CustomCallFusedAttnDescriptor {
float dropout_probability;
NVTE_Bias_Type bias_type;
NVTE_Mask_Type mask_type;
NVTE_QKV_Layout qkv_layout;
DType dtype;
DType wkspace_dtype;
bool is_training;
};
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t input_batch, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
size_t wkspace_size, float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
DType dtype, DType wkspace_dtype, bool is_training);
size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
bool is_training);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
......@@ -207,55 +208,19 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers,
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t max_seqlen,
size_t attn_heads, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t max_seqlen,
size_t attn_heads, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training);
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training);
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
......
......@@ -26,7 +26,7 @@ from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
from ..fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from ..fused_attn import is_fused_attn_kernel_available, canonicalize_attn_mask_type
from ..fused_attn import self_fused_attn, cross_fused_attn, fused_attn
from ..fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked, fused_attn
from ..softmax import SoftmaxType
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
......@@ -190,16 +190,19 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-publi
def convert_to_softmax_type(attn_mask_type, mask):
"""Convert the attn_mask_type to SoftmaxType"""
# mask is ignored for no_mask and causal_mask
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
mask = None
if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
return SoftmaxType.SCALED_UPPER_TRIANG_MASKED
return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]:
if mask is not None:
return SoftmaxType.SCALED_MASKED
return SoftmaxType.SCALED
raise ValueError(f"Unsupported {attn_mask_type=}, "
"supported attn_mask_type = {'causal', 'padding'}")
return SoftmaxType.SCALED_MASKED, mask
return SoftmaxType.SCALED, mask
raise ValueError(f"Unsupported {attn_mask_type=}, supported attn_mask_type="
"{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}")
softmax_type = convert_to_softmax_type(self.attn_mask_type, mask)
softmax_type, mask = convert_to_softmax_type(self.attn_mask_type, mask)
attn_weights = Softmax(softmax_type=softmax_type,
scale_factor=fused_scale_factor)(attn_weights, mask,
......@@ -266,7 +269,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
qkv_packed = query
if self.transpose_batch_sequence:
qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4])
x = self_fused_attn(qkv_packed,
x = fused_attn_qkvpacked(qkv_packed,
bias,
mask,
seed,
......@@ -285,7 +288,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
if self.transpose_batch_sequence:
query = query.transpose([1, 0, 2, 3])
kv_packed = kv_packed.transpose([1, 0, 2, 3, 4])
x = cross_fused_attn(query,
x = fused_attn_kvpacked(query,
kv_packed,
bias,
mask,
......@@ -358,11 +361,27 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-method
attention_dropout: float, default = 0.0
Dropout probability for the dropout op after the softmax.
attn_mask_type: str, default = 'causal'
Type of the attention mask passed into softmax operation in the self attention.
Available options: {'no_mask', 'padding', 'causal', 'causal_padding'}
Introduced in v0.10.0.
This parameter specifies the type of attention mask to be applied during the softmax
operation.
Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}
Each described below:
* no_mask: No attention mask is applied. This means the attention will consider the
full sequence without any restrictions.
* padding: Indicates the presence of padding at the end of each sequence.
Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
:attr:`__call__` method to specify the padding positions.
* causal: An upper triangular mask is applied to the softmax inputs,
ensuring that the prediction for a certain position is only dependent on known outputs
from positions before it.
* causal_padding / padding_causal: A combination of both causal and padding masks.
Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.
.. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.
attn_bias_type: Optional[str], default = None
Type of the attention bias passed in the self attention.
Type of the attention bias passed in the attention.
Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
When left as the default (None), the type is decided automatically by the MHA's bias parameter:
:attr:`post_scale_bias` if a bias is provided, otherwise :attr:`no_bias`.
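To illustrate the padding mask described above, here is a minimal sketch, assuming a hypothetical per-sequence count of valid kv tokens (kv_valid_lens); True marks positions to be masked out and the result has the documented shape [batch, 1, max_seqlen_q, max_seqlen_kv]:

import jax.numpy as jnp

def make_padding_mask(batch, max_seqlen_q, max_seqlen_kv, kv_valid_lens):
    # kv_valid_lens is a hypothetical int32 array of shape (batch,) holding the
    # number of valid (non-padded) kv tokens per sequence.
    kv_pos = jnp.arange(max_seqlen_kv)[None, None, None, :]      # (1, 1, 1, S_kv)
    pad = kv_pos >= kv_valid_lens[:, None, None, None]           # True where padded
    return jnp.broadcast_to(pad, (batch, 1, max_seqlen_q, max_seqlen_kv))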
......@@ -438,6 +457,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-method
mask: jax.numpy.ndarray, default = None
Boolean tensor used to mask out the attention softmax input.
:attr:`True` means to mask out the corresponding values.
Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
bias: jax.numpy.ndarray, default = None
A tensor used to shift attention softmax input.
*:
......@@ -639,9 +659,25 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
attention_dropout: float, default = 0.0
Dropout probability for the dropout op after the softmax.
attn_mask_type: str, default = 'causal'
Type of the attention mask passed into softmax operation in the attention.
Available options: {'no_mask', 'padding', 'causal', 'causal_padding'}
Introduced in v0.10.0.
This parameter specifies the type of attention mask to be applied during the softmax
operation.
Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}
Each described below:
* no_mask: No attention mask is applied. This means the attention will consider the
full sequence without any restrictions.
* padding: Indicates the presence of padding at the end of each sequence.
Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
:attr:`__call__` method to specify the padding positions.
* causal: An upper triangular mask is applied to the softmax inputs,
ensuring that the prediction for a certain position is only dependent on known outputs
from positions before it.
* causal_padding / padding_causal: A combination of both causal and padding masks.
Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.
.. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.
attn_bias_type: Optional[str], default = None
Type of the attention bias passed in the attention.
Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
......@@ -809,6 +845,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
mask: jax.numpy.ndarray, default = None
Boolean tensor used to mask out the attention softmax input.
:attr:`True` means mask out the corresponding values.
Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
bias: jax.numpy.ndarray, default = None
A tensor used to shift the attention softmax input.
*
......@@ -1299,9 +1336,25 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
is added after self-attention. This can be used for structures like `T5`
Transformer in conjunction with the TransformerLayerType.ENCODER option.
self_attn_mask_type: str, default = 'causal'
Type of the attention mask passed into softmax operation in the self attention.
Available options: {'no_mask', 'padding', 'causal', 'causal_padding'}
Introduced in v0.10.0.
This parameter specifies the type of attention mask to be applied during the softmax
operation in the self attention.
Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}
Each described below:
* no_mask: No attention mask is applied. This means the self attention will consider the
full sequence without any restrictions.
* padding: Indicates the presence of padding at the end of each sequence.
Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
:attr:`__call__` method to specify the padding positions.
* causal: An upper triangular mask is applied to the softmax inputs,
ensuring that the prediction for a certain position is only dependent on known outputs
from positions before it.
* causal_padding / padding_causal: A combination of both causal and padding masks.
Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.
.. note:: :attr:`attention_mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.
self_attn_bias_type: Optional[str], default = None
Type of the attention bias passed into the self attention.
Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
......@@ -1420,9 +1473,12 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
:attr:`layer_type=TransformerLayerType.DECODER`.
attention_mask : jax.numpy.ndarray, default = None
Boolean tensor used to mask out self-attention softmax input.
:attr:`True` means mask out the corresponding values.
Ignored when :attr:`self.self_attn_mask_type` is either 'no_mask' or 'causal'.
encoder_decoder_mask: jax.numpy.ndarray, default = None
Boolean tensor used to mask out cross-attention softmax input when
:attr:`layer_type=TransformerLayerType.DECODER`.
:attr:`True` means mask out the corresponding values.
deterministic: bool, default = False
Disable dropout layers if set to True.
decode: bool, default = False
......
......@@ -14,20 +14,29 @@ from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout
from .cpp_extensions import FusedAttnHelper
from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd
from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd
from .cpp_extensions import fused_attn_fwd_kvpacked, fused_attn_bwd_kvpacked
from .cpp_extensions import fused_attn_fwd_qkvpacked, fused_attn_bwd_qkvpacked
from .cpp_extensions import fused_attn_fwd, fused_attn_bwd
class AttnBiasType(Enum):
"""Attention Bias Type."""
"""
NO_BIAS: Softmax is performed as softmax(scale * qk)
PRE_SCALE_BIAS: Softmax is performed as softmax(scale * (qk + bias))
POST_SCALE_BIAS: Softmax is performed as softmax(scale * qk + bias)
"""
NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS
PRE_SCALE_BIAS = NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS
POST_SCALE_BIAS = NVTE_Bias_Type.NVTE_POST_SCALE_BIAS
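As a minimal reference sketch of the three bias modes above (the plain mathematical definition only, not the fused kernel; logits stands for the raw q @ k^T scores and bias is assumed broadcastable to it):

from jax.nn import softmax

def reference_softmax(logits, bias, scale, bias_type):
    # Illustrative only; mirrors the formulas in the AttnBiasType docstring.
    if bias_type == 'no_bias':           # softmax(scale * qk)
        return softmax(scale * logits, axis=-1)
    if bias_type == 'pre_scale_bias':    # softmax(scale * (qk + bias))
        return softmax(scale * (logits + bias), axis=-1)
    if bias_type == 'post_scale_bias':   # softmax(scale * qk + bias)
        return softmax(scale * logits + bias, axis=-1)
    raise ValueError(bias_type)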
class AttnMaskType(Enum):
"""Attention Mask Type."""
"""
NO_MASK: No attention mask is applied.
PADDING_MASK: Indicates the presence of padding at the end of each sequence.
CAUSAL_MASK: An upper triangular mask is applied to the softmax inputs.
PADDING_CAUSAL_MASK: A combination of both causal and padding masks.
"""
NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK
PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK
CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK
......@@ -47,33 +56,38 @@ def canonicalize_attn_mask_type(attn_mask_type: str):
The overhead between the padding and non-padding versions should be small.
However, we will lift this limitation in the near future.
"""
if attn_mask_type in ['causal', 'padding_causal']:
return AttnMaskType.PADDING_CAUSAL_MASK
if attn_mask_type in ['no_mask', 'padding']:
match attn_mask_type:
case 'no_mask':
return AttnMaskType.NO_MASK
case 'padding':
return AttnMaskType.PADDING_MASK
raise ValueError(f"Unsupported {attn_mask_type=}, "
"supported attn_mask_type={'no_mask', 'padding', 'causal', 'padding_causal'}")
case 'causal':
return AttnMaskType.CAUSAL_MASK
case 'padding_causal' | 'causal_padding':
return AttnMaskType.PADDING_CAUSAL_MASK
raise ValueError(f"Unsupported {attn_mask_type=}, supported attn_mask_type="
"{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}")
def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, attn_mask_type,
dropout_probability, num_heads_q, num_heads_kv, max_seqlen_q,
max_seqlen_kv, head_dim):
def is_fused_attn_kernel_available(q_dtype, kv_dtype, qkv_layout, attn_bias_type, attn_mask_type,
dropout_probability, q_num_heads, kv_num_heads, q_max_seqlen,
kv_max_seqlen, head_dim):
"""
To check whether the fused attention kernel is available
To check whether the fused attention kernel is supported
"""
return FusedAttnHelper(q_type, kv_type, qkv_layout.value, attn_bias_type.value,
attn_mask_type.value, dropout_probability, num_heads_q, num_heads_kv,
max_seqlen_q, max_seqlen_kv, head_dim).is_fused_attn_kernel_available()
return FusedAttnHelper(q_dtype, kv_dtype, qkv_layout.value, attn_bias_type.value,
attn_mask_type.value, dropout_probability, q_num_heads, kv_num_heads,
q_max_seqlen, kv_max_seqlen, head_dim).is_fused_attn_kernel_available()
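An illustrative sketch of how these two helpers fit together; the dtypes, head counts, and sequence lengths below are hypothetical:

import jax.numpy as jnp

# 'causal_padding' and 'padding_causal' canonicalize to the same enum value.
mask_type = canonicalize_attn_mask_type('causal_padding')    # AttnMaskType.PADDING_CAUSAL_MASK

can_use_fused = is_fused_attn_kernel_available(
    q_dtype=jnp.float16, kv_dtype=jnp.float16,
    qkv_layout=QKVLayout.BS3HD,
    attn_bias_type=AttnBiasType.NO_BIAS,
    attn_mask_type=mask_type,
    dropout_probability=0.0,
    q_num_heads=16, kv_num_heads=16,
    q_max_seqlen=512, kv_max_seqlen=512,
    head_dim=64)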
def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
def fused_attn_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
seed: jnp.ndarray | None, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
"""
Self fused attention wrapper
Fused attention with the qkvpacked inputs
"""
output = _self_fused_attn(qkv,
output = _fused_attn_qkvpacked(qkv,
bias,
mask,
seed,
......@@ -87,29 +101,30 @@ def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarra
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
def _fused_attn_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
seed: jnp.ndarray | None, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
output, _ = _self_fused_attn_fwd_rule(qkv, bias, mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training)
output, _ = _fused_attn_fwd_qkvpacked_rule(qkv, bias, mask, seed, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability,
is_training)
return output
def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray | None,
mask: jnp.ndarray, seed: jnp.ndarray | None,
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
scaling_factor: float, dropout_probability: float,
is_training: bool):
if mask is None:
def _fused_attn_fwd_qkvpacked_rule(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
seed: jnp.ndarray | None, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
batch, seqlen, *_ = qkv.shape
actual_seqlen = jnp.full((batch,), seqlen, dtype=jnp.int32)
else:
assert mask is not None
mask = jnp.logical_not(mask)
actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
output, softmax_aux, rng_state = self_fused_attn_fwd(qkv,
output, softmax_aux, rng_state = fused_attn_fwd_qkvpacked(
qkv,
bias,
actual_seqlen,
seed,
......@@ -124,11 +139,11 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray | None,
return output, (qkv, bias, softmax_aux, rng_state, output, actual_seqlen)
def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, dz):
def _fused_attn_bwd_qkvpacked_rule(attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training, ctx, dz):
qkv, bias, softmax_aux, rng_state, output, actual_seqlen = ctx
grad_qkv, grad_bias = self_fused_attn_bwd(qkv,
grad_qkv, grad_bias = fused_attn_bwd_qkvpacked(qkv,
bias,
softmax_aux,
rng_state,
......@@ -147,17 +162,18 @@ def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dr
return grad_qkv, grad_bias, None, None
_self_fused_attn.defvjp(_self_fused_attn_fwd_rule, _self_fused_attn_bwd_rule)
_fused_attn_qkvpacked.defvjp(_fused_attn_fwd_qkvpacked_rule, _fused_attn_bwd_qkvpacked_rule)
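A hedged usage sketch of the renamed entry point; the batch size, sequence length, head count, scale, and dtype below are hypothetical:

import jax.numpy as jnp

# qkv packed as [batch, seqlen, 3, heads, head_dim], matching the BS3HD layout.
qkv = jnp.zeros((2, 128, 3, 8, 64), dtype=jnp.bfloat16)
out = fused_attn_qkvpacked(qkv,
                           bias=None,
                           mask=None,                      # allowed for no_mask/causal
                           seed=None,
                           attn_bias_type=AttnBiasType.NO_BIAS,
                           attn_mask_type=AttnMaskType.CAUSAL_MASK,
                           scaling_factor=1.0 / 8.0,       # 1/sqrt(head_dim) for head_dim=64
                           dropout_probability=0.0,
                           is_training=True)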
def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
scaling_factor: float, dropout_probability: float, is_training: bool):
def fused_attn_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
seed: jnp.ndarray, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
"""
Cross multi-head attention wrapper
Fused attention with the kvpacked inputs
"""
output = _cross_fused_attn(q,
output = _fused_attn_kvpacked(q,
kv,
bias,
mask,
......@@ -172,32 +188,36 @@ def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: j
@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9))
def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
scaling_factor: float, dropout_probability: float, is_training: bool):
def _fused_attn_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
seed: jnp.ndarray, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
output, _ = _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training)
output, _ = _fused_attn_fwd_kvpacked_rule(q, kv, bias, mask, seed, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability,
is_training)
return output
def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
def _fused_attn_fwd_kvpacked_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
if mask is None:
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
batch, s_q, *_ = q.shape
s_kv = kv.shape[1]
q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
else:
assert mask is not None
mask = jnp.logical_not(mask)
q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
if attn_mask_type == AttnMaskType.PADDING_MASK:
kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
else:
# When the mask is causal, the last row does not give the actual kv seqlen; use max to find it
kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))
output, softmax_aux, rng_state = cross_fused_attn_fwd(q,
output, softmax_aux, rng_state = fused_attn_fwd_kvpacked(
q,
kv,
bias,
q_actual_seqlen,
......@@ -214,11 +234,11 @@ def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mas
return output, (q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen)
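A small numeric example of why the causal branch above uses a max over row sums rather than reading the last row (shapes and values are hypothetical):

import jax.numpy as jnp

# A 4x4 causal mask with True meaning masked out, as in the rules above.
causal = jnp.triu(jnp.ones((4, 4), dtype=bool), k=1)[None, None]   # (1, 1, 4, 4)
keep = jnp.logical_not(causal)
row_sums = jnp.sum(keep, axis=-1, dtype=jnp.int32)                 # [[[1, 2, 3, 4]]]
kv_actual_seqlen = jnp.max(row_sums, axis=(-1, -2))                # [4], the true kv length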
def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, dz):
def _fused_attn_bwd_kvpacked_rule(attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training, ctx, dz):
q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx
grad_q, grad_kv, grad_bias = cross_fused_attn_bwd(q,
grad_q, grad_kv, grad_bias = fused_attn_bwd_kvpacked(q,
kv,
bias,
softmax_aux,
......@@ -239,7 +259,7 @@ def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, d
return grad_q, grad_kv, grad_bias, None, None
_cross_fused_attn.defvjp(_cross_fused_attn_fwd_rule, _cross_fused_attn_bwd_rule)
_fused_attn_kvpacked.defvjp(_fused_attn_fwd_kvpacked_rule, _fused_attn_bwd_kvpacked_rule)
def fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
......@@ -277,15 +297,16 @@ def _fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarra
def _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
if mask is None:
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
batch, s_q, *_ = q.shape
s_kv = k.shape[1]
q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
else:
assert mask is not None
mask = jnp.logical_not(mask)
q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
if attn_mask_type == AttnMaskType.PADDING_MASK:
kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
else:
# When the mask is causal, the last row does not give the actual kv seqlen; use max to find it
......