Unverified commit 8e672ff0, authored by Reese Wang, committed by GitHub

[JAX] Refactor fused attention (#711)

* Remove unused headers
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Unify the fused attn workspace size cpp code
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Reduce the skipped cases
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Rename self/cross attention to qkvpacked/kvpacked
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Update attention mask docs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refine the attn mask implementations
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent b855656b
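For review convenience, a minimal sketch of calling the renamed entry points. Shapes, dtypes, and keyword-style arguments here are illustrative assumptions, not part of this commit; they follow the packed layouts exercised in the tests below (BS3HD packs q/k/v on axis -3, BSHD_BS2HD packs k/v):

    import jax
    import jax.numpy as jnp
    from transformer_engine.jax.fused_attn import (AttnBiasType, AttnMaskType,
                                                   fused_attn_qkvpacked, fused_attn_kvpacked)

    b, s, h, d = 4, 128, 16, 64                                # hypothetical sizes
    mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)            # True/1 marks masked-out positions
    qkv = jax.random.normal(jax.random.PRNGKey(0), (b, s, 3, h, d), jnp.bfloat16)    # BS3HD
    out = fused_attn_qkvpacked(qkv, None, mask, None,
                               attn_bias_type=AttnBiasType.NO_BIAS,
                               attn_mask_type=AttnMaskType.PADDING_MASK,
                               scaling_factor=1.0 / d**0.5,
                               dropout_probability=0.0,
                               is_training=False)

    q = jax.random.normal(jax.random.PRNGKey(1), (b, s, h, d), jnp.bfloat16)         # BSHD
    kv = jax.random.normal(jax.random.PRNGKey(2), (b, s, 2, h, d), jnp.bfloat16)     # BS2HD
    out = fused_attn_kvpacked(q, kv, None, mask, None,
                              attn_bias_type=AttnBiasType.NO_BIAS,
                              attn_mask_type=AttnMaskType.PADDING_MASK,
                              scaling_factor=1.0 / d**0.5,
                              dropout_probability=0.0,
                              is_training=False)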
@@ -16,7 +16,7 @@ from distributed_test_base import compare_ops
 from utils import make_causal_mask, make_self_mask
 from transformer_engine.jax import fp8_autocast
 from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
-from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn
+from transformer_engine.jax.fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked
 from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout

 DTYPES = [jnp.float16, jnp.bfloat16]
@@ -86,7 +86,7 @@ class TestDistributedSelfAttn:
         def target_func(qkv, bias, mask):
             return jnp.mean(
-                self_fused_attn(qkv,
+                fused_attn_qkvpacked(qkv,
                                 bias,
                                 mask,
                                 None,
@@ -192,7 +192,7 @@ class TestDistributedCrossAttn:
         def target_func(q, kv, mask):
             return jnp.mean(
-                cross_fused_attn(q,
+                fused_attn_kvpacked(q,
                                  kv,
                                  None,
                                  mask,
@@ -2,8 +2,6 @@
 #
 # See LICENSE for license information.
 """Tests for fused attention"""
-import sys
 from enum import Enum
 from dataclasses import dataclass
 from functools import partial
@@ -21,7 +19,7 @@ from jax import value_and_grad, jit
 from jax.typing import ArrayLike, DTypeLike
 from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
-from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn, fused_attn
+from transformer_engine.jax.fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked, fused_attn
 from transformer_engine.jax.cpp_extensions import FusedAttnHelper
 from transformer_engine_jax import NVTE_Fused_Attn_Backend
@@ -144,11 +142,11 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng
         case QKVLayout.BS3HD:
             query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
             qkv = jnp.concatenate((query, key, value), axis=-3)
-            return self_fused_attn(qkv, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
+            return fused_attn_qkvpacked(qkv, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
         case QKVLayout.BSHD_BS2HD:
             key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
             kv = jnp.concatenate((key, value), axis=-3)
-            return cross_fused_attn(query, kv, bias, mask, dropout_rng,
-                                    **kwargs).astype(query.dtype)
+            return fused_attn_kvpacked(query, kv, bias, mask, dropout_rng,
+                                       **kwargs).astype(query.dtype)
         case QKVLayout.BSHD_BSHD_BSHD:
             return fused_attn(query, key, value, bias, mask, dropout_rng,
@@ -156,6 +154,10 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng
 class BiasShape(Enum):
+    """
+    Enum class to represent the different bias shapes used in the fused attention.
+    """
     BIAS_1HSS = '1HSS'
     BIAS_B1SS = 'B1SS'
     BIAS_BHSS = 'BHSS'
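The four shape names encode how a bias broadcasts against attention logits of shape [b, h, s_q, s_kv]; a small illustrative sketch (dimension values are hypothetical):

    import jax.numpy as jnp

    b, h, s_q, s_kv = 2, 4, 8, 8
    logits = jnp.zeros((b, h, s_q, s_kv))
    for bias_shape in [(1, h, s_q, s_kv),      # 1HSS: shared across the batch
                       (b, 1, s_q, s_kv),      # B1SS: shared across heads
                       (b, h, s_q, s_kv),      # BHSS: fully per-batch, per-head
                       (1, 1, s_q, s_kv)]:     # 11SS: shared across batch and heads
        assert (logits + jnp.zeros(bias_shape)).shape == (b, h, s_q, s_kv)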
@@ -188,17 +190,16 @@ class FusedAttnRunner:
         if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv:
             pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.")
-        self.backend = FusedAttnHelper(
-            self.dtype, self.dtype, self.qkv_layout.value, self.attn_bias_type.value,
-            self.attn_mask_type.value, self.dropout_prob, self.num_heads_q, self.num_heads_kv,
-            self.max_seqlen_q, self.max_seqlen_kv, self.head_dim).get_fused_attn_backend()
+        self.backend = FusedAttnHelper(self.dtype, self.dtype, self.qkv_layout.value,
+                                       self.attn_bias_type.value, self.attn_mask_type.value,
+                                       self.dropout_prob, self.num_heads_q, self.num_heads_kv,
+                                       self.max_seqlen_q, self.max_seqlen_kv,
+                                       self.head_dim).get_fused_attn_backend()
         if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
             pytest.skip("Unsupported inputs combination or device compute capability.")
-        if self.bias_shape != BiasShape.BIAS_1HSS:
-            if self.attn_bias_type != AttnBiasType.POST_SCALE_BIAS:
-                pytest.skip("B1SS, BHSS and 11SS bias shapes require POST_SCALE_BIAS.")
-            elif self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
+        if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
+            if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
                 pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for "
                             "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK.")
         elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
@@ -213,7 +214,9 @@ class FusedAttnRunner:
         q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
         k_shape = v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim)
-        if self.bias_shape == BiasShape.BIAS_1HSS:
+        if self.attn_bias_type == AttnBiasType.NO_BIAS:
+            bias_shape = None
+        elif self.bias_shape == BiasShape.BIAS_1HSS:
             bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
         elif self.bias_shape == BiasShape.BIAS_B1SS:
             bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
@@ -222,7 +225,7 @@ class FusedAttnRunner:
         elif self.bias_shape == BiasShape.BIAS_11SS:
             bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
         else:
-            pytest.xfail("PyTest attempted to use an unrecognized bias layout!")
+            pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")
         self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.)
         self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.)
@@ -327,8 +330,8 @@ class FusedAttnRunner:
                                                    **kwargs), arg_nums))
         jitted_reference = jit(
             value_and_grad(
-                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args,
-                                                       **kwargs), arg_nums))
+                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
+                arg_nums))

         primitive_out, primitive_dgrad = jitted_primitive(*args)
         reference_out, reference_dgrad = jitted_reference(*args)
@@ -361,9 +364,9 @@ class FusedAttnRunner:
             primitive_dbias = jnp.float32(primitive_dgrad[3])
             reference_dbias = jnp.float32(reference_dgrad[3])

-            assert_allclose(
-                primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
-                jnp.zeros_like(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:]),
-                dtype=self.dtype)
+            assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
+                            jnp.zeros_like(primitive_dbias[..., self.valid_len_q:,
+                                                           self.valid_len_kv:]),
+                            dtype=self.dtype)

             # dbias padded part
@@ -376,15 +379,13 @@ class FusedAttnRunner:
                             reference_dbias[..., :self.valid_len_q, :self.valid_len_kv],
                             dtype=self.dtype)

-@pytest.mark.parametrize('bias_shape', [
-    pytest.param(BiasShape.BIAS_1HSS, id='1-H-S-S'),
-    pytest.param(BiasShape.BIAS_B1SS, id='B-1-S-S'),
-    pytest.param(BiasShape.BIAS_BHSS, id='B-H-S-S'),
-    pytest.param(BiasShape.BIAS_11SS, id='1-1-S-S'),
-])
-@pytest.mark.parametrize('attn_bias_type', [
-    pytest.param(AttnBiasType.NO_BIAS, id='NO_BIAS'),
-    pytest.param(AttnBiasType.POST_SCALE_BIAS, id='POST_SCALE_BIAS'),
+@pytest.mark.parametrize('attn_bias_type, bias_shape', [
+    pytest.param(AttnBiasType.NO_BIAS, None, id='NO_BIAS'),
+    pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id='POST_SCALE_BIAS-1HSS'),
+    pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id='POST_SCALE_BIAS-B1SS'),
+    pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id='POST_SCALE_BIAS-BHSS'),
+    pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id='POST_SCALE_BIAS-11SS'),
 ])
 @pytest.mark.parametrize('attn_mask_type', [
     pytest.param(AttnMaskType.NO_MASK, id='NO_MASK'),
@@ -399,31 +400,32 @@ class FusedAttnRunner:
 ])
 @pytest.mark.parametrize('dtype', [
     pytest.param(jnp.bfloat16, id="BF16"),
-    pytest.param(jnp.float16, id="FP16")
+    pytest.param(jnp.float16, id="FP16"),
 ])
-@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d',[
+@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d', [
     pytest.param(32, 128, 128, 16, 16, 64, id='32-128-128-16-16-64-SELF'),
-    pytest.param( 4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'),
+    pytest.param(4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'),
     pytest.param(32, 512, 128, 16, 16, 64, id='32-512-128-16-16-64-CROSS'),
-    pytest.param( 4, 2048, 1024, 12, 12, 64, id='4-2048-1048-12-12-64-CROSS'),
+    pytest.param(4, 2048, 1024, 12, 12, 64, id='4-2048-1048-12-12-64-CROSS'),
     pytest.param(32, 128, 128, 16, 8, 64, id='32-128-128-16-8-64-GQA'),
-    pytest.param( 4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA')
+    pytest.param(4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA'),
 ])
 @pytest.mark.parametrize('dropout_prob', [
     pytest.param(0.0, id="DROP_0.0"),
-    pytest.param(0.1, id="DROP_0.1")
-])
-@pytest.mark.parametrize('is_training', [
-    pytest.param(True, id='TRAINING'),
-    pytest.param(False, id='INFERENCE'),
+    pytest.param(0.1, id="DROP_0.1"),
 ])
 class TestFusedAttn:
     """
     Fused attention tester
     """

     @staticmethod
-    def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
-                     dropout_prob, dtype, is_training, qkv_layout, bias_shape):
+    @pytest.mark.parametrize('is_training', [
+        pytest.param(True, id='TRAINING'),
+        pytest.param(False, id='INFERENCE'),
+    ])
+    def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
+                     dtype, is_training, qkv_layout, bias_shape):
         """
         Test forward with parameterized configs
         """
@@ -432,13 +434,11 @@ class TestFusedAttn:
         runner.test_forward()

     @staticmethod
-    def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
-                      dropout_prob, dtype, is_training, qkv_layout, bias_shape):
+    def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
+                      dtype, qkv_layout, bias_shape):
         """
         Test backward with parameterized configs
         """
-        if not is_training:
-            pytest.skip("Backward pass does not support inference.")
         runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
                                  dropout_prob, dtype, True, qkv_layout, bias_shape)
         runner.test_backward()
@@ -449,6 +449,7 @@ class TestDecoderLayer:
                 hidden_dropout_dims=(sequence_dim,),
                 intermediate_dropout_dims=(sequence_dim,),
                 layer_type=TransformerLayerType.DECODER,
+                self_attn_mask_type='padding_causal',
                 dtype=dtype,
                 **te_layer_attrs)
         ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
@@ -497,6 +498,7 @@ class TestDecoderLayer:
                 hidden_dropout_dims=(sequence_dim,),
                 intermediate_dropout_dims=(sequence_dim,),
                 layer_type=TransformerLayerType.DECODER,
+                self_attn_mask_type='padding_causal',
                 dtype=dtype,
                 **te_layer_attrs)
         ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
@@ -730,8 +730,13 @@ class TestDotProductAttn(TestLayer):
     def input_getter(self, shape, dtype):
         key = jax.random.PRNGKey(seed=1234)
         q_key, k_key, v_key = jax.random.split(key, 3)
-        return list(map(partial(jax.random.normal, shape=shape, dtype=dtype),
-                        [q_key, k_key, v_key]))
+        b, s, *_ = shape
+        if self.attrs[DotProductAttnAttr.TRANSPOSE_BS]:
+            b, s = s, b
+        mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
+        return [
+            *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]), mask
+        ]

     def get_layer_name(self):
         return 'dot_product_attn'
@@ -765,6 +770,7 @@ class TestDotProductAttn(TestLayer):
     @pytest.mark.parametrize('dtype', DTYPE)
     @pytest.mark.parametrize('attrs', DotProductAttnAttr.ATTRS)
     def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
+        self.attrs = attrs
         praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
         self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@@ -853,9 +859,11 @@ class MultiHeadAttnAttr:
 class TestMultiHeadAttn(TestLayer):

     def input_getter(self, shape, dtype):
-        data_key = jax.random.PRNGKey(seed=1234)
-        return (jax.random.normal(data_key, shape,
-                                  dtype), jax.random.normal(data_key, shape, dtype))
+        key = jax.random.PRNGKey(seed=1234)
+        q_key, kv_key = jax.random.split(key, 2)
+        s, b, *_ = shape
+        mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
+        return [*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask]

     def get_layer_name(self):
         return 'multi_head_attn'
@@ -1183,9 +1191,15 @@ class TransformerLayerAttr:
 class TestTransformer(TestLayer):

     def input_getter(self, shape, dtype):
-        data_key = jax.random.PRNGKey(seed=1234)
-        return (jax.random.normal(data_key, shape,
-                                  dtype), jax.random.normal(data_key, shape, dtype))
+        key = jax.random.PRNGKey(seed=1234)
+        q_key, kv_key = jax.random.split(key, 2)
+        b, s, *_ = shape
+        if self.attrs[TransformerLayerAttr.TRANSPOSE_BS]:
+            b, s = s, b
+        mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
+        return [
+            *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask, mask
+        ]

     def get_layer_name(self):
         return 'transformerlayer'
@@ -1277,6 +1291,7 @@ class TestTransformer(TestLayer):
     @pytest.mark.parametrize('dtype', DTYPE)
     @pytest.mark.parametrize('attrs', TransformerLayerAttr.ATTRS)
     def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
+        self.attrs = attrs
         praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
         self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@@ -1292,7 +1307,7 @@ class TestTransformer(TestLayer):
                               fp8_format,
                               rtol=1e-05,
                               atol=1e-08):
+        self.attrs = attrs
         ds = DelayedScaling(fp8_format=fp8_format)
         with fp8_autocast(enabled=True, fp8_recipe=ds):
             praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
@@ -49,10 +49,6 @@ pybind11::dict Registrations() {
         EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForward);
     dict["te_scaled_upper_triang_masked_softmax_backward"] =
         EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward);
-    dict["te_self_fused_attn_forward"] = EncapsulateFunction(SelfFusedAttnForward);
-    dict["te_self_fused_attn_backward"] = EncapsulateFunction(SelfFusedAttnBackward);
-    dict["te_cross_fused_attn_forward"] = EncapsulateFunction(CrossFusedAttnForward);
-    dict["te_cross_fused_attn_backward"] = EncapsulateFunction(CrossFusedAttnBackward);
     dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward);
     dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward);
     return dict;
@@ -72,10 +68,6 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
     m.def("get_dgelu_dbias_ct_workspace_sizes", &GetDGeluDBiasCastTransposeWorkspaceSizes);
     m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes);
     m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes);
-    m.def("get_self_fused_attn_fwd_workspace_sizes", &GetSelfFusedAttnForwardWorkspaceSizes);
-    m.def("get_self_fused_attn_bwd_workspace_sizes", &GetSelfFusedAttnBackwardWorkspaceSizes);
-    m.def("get_cross_fused_attn_fwd_workspace_sizes", &GetCrossFusedAttnForwardWorkspaceSizes);
-    m.def("get_cross_fused_attn_bwd_workspace_sizes", &GetCrossFusedAttnBackwardWorkspaceSizes);
     m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes);
     m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
@@ -118,17 +118,18 @@ struct CustomCallFusedAttnDescriptor {
     float dropout_probability;
     NVTE_Bias_Type bias_type;
     NVTE_Mask_Type mask_type;
+    NVTE_QKV_Layout qkv_layout;
     DType dtype;
     DType wkspace_dtype;
     bool is_training;
 };

 pybind11::bytes PackCustomCallFusedAttnDescriptor(
-    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
+    size_t input_batch, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
     size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
-    size_t wkspace_size, float scaling_factor, float dropout_probability,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-    DType dtype, DType wkspace_dtype, bool is_training);
+    size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
+    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
+    bool is_training);

 NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                             NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
@@ -207,55 +208,19 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers,
 void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                             std::size_t opaque_len);

-pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
-    size_t input_batch, size_t bias_batch, size_t max_seqlen,
-    size_t attn_heads, size_t bias_heads, size_t head_dim,
-    float scaling_factor, float dropout_probability,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
-
-void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
-                          size_t opaque_len);
-
-pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
-    size_t input_batch, size_t bias_batch, size_t max_seqlen,
-    size_t attn_heads, size_t bias_heads, size_t head_dim,
-    float scaling_factor, float dropout_probability,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
-
-void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
-                           size_t opaque_len);
-
-pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
-    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
-    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
-    float scaling_factor, float dropout_probability,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
-
-void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
-                           size_t opaque_len);
-
-pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
-    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
-    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
-    float scaling_factor, float dropout_probability,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
-
-void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
-                            size_t opaque_len);
-
 pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
     size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
     size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
-    float scaling_factor, float dropout_probability,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
+    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
+    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training);

 void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

 pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
     size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
     size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
-    float scaling_factor, float dropout_probability,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
+    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
+    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training);

 void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
@@ -26,7 +26,7 @@ from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
 from .module import LayerNorm, Softmax
 from ..fused_attn import AttnBiasType, AttnMaskType, QKVLayout
 from ..fused_attn import is_fused_attn_kernel_available, canonicalize_attn_mask_type
-from ..fused_attn import self_fused_attn, cross_fused_attn, fused_attn
+from ..fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked, fused_attn
 from ..softmax import SoftmaxType
 from ..sharding import num_of_devices
 from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
@@ -190,16 +190,19 @@ class _UnfusedDotProductAttention(nn.Module):    # pylint: disable=too-few-public-methods
         def convert_to_softmax_type(attn_mask_type, mask):
             """Convert the attn_mask_type to SoftmaxType"""
+            # mask is ignored for no_mask and causal_mask
+            if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
+                mask = None
             if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
-                return SoftmaxType.SCALED_UPPER_TRIANG_MASKED
+                return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask
             if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]:
                 if mask is not None:
-                    return SoftmaxType.SCALED_MASKED
-                return SoftmaxType.SCALED
-            raise ValueError(f"Unsupported {attn_mask_type=}, "
-                             "supported attn_mask_type = {'causal', 'padding'}")
+                    return SoftmaxType.SCALED_MASKED, mask
+                return SoftmaxType.SCALED, mask
+            raise ValueError(f"Unsupported {attn_mask_type=}, supported attn_mask_type="
+                             "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}")

-        softmax_type = convert_to_softmax_type(self.attn_mask_type, mask)
+        softmax_type, mask = convert_to_softmax_type(self.attn_mask_type, mask)
         attn_weights = Softmax(softmax_type=softmax_type,
                                scale_factor=fused_scale_factor)(attn_weights, mask,
@@ -266,7 +269,7 @@ class _FusedDotProductAttention(nn.Module):    # pylint: disable=too-few-public-methods
             qkv_packed = query
             if self.transpose_batch_sequence:
                 qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4])
-            x = self_fused_attn(qkv_packed,
+            x = fused_attn_qkvpacked(qkv_packed,
                                 bias,
                                 mask,
                                 seed,
@@ -285,7 +288,7 @@ class _FusedDotProductAttention(nn.Module):    # pylint: disable=too-few-public-methods
             if self.transpose_batch_sequence:
                 query = query.transpose([1, 0, 2, 3])
                 kv_packed = kv_packed.transpose([1, 0, 2, 3, 4])
-            x = cross_fused_attn(query,
+            x = fused_attn_kvpacked(query,
                                  kv_packed,
                                  bias,
                                  mask,
@@ -358,11 +361,27 @@ class DotProductAttention(nn.Module):    # pylint: disable=too-few-public-methods
     attention_dropout: float, default = 0.0
         Dropout probability for the dropout op after the softmax.
     attn_mask_type: str, default = 'causal'
-        Type of the attention mask passed into softmax operation in the self attention.
-        Available options: {'no_mask', 'padding', 'causal', 'causal_padding'}
-        Introduced in v0.10.0.
+        This parameter specifies the type of attention mask to be applied during the softmax
+        operation.
+        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}.
+        Each is described below:
+
+        * no_mask: No attention mask is applied. This means the attention will consider the
+          full sequence without any restrictions.
+        * padding: Indicates the presence of padding at the end of each sequence.
+          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
+          :attr:`__call__` method to specify the padding positions.
+        * causal: An upper triangular mask is applied to the softmax inputs,
+          ensuring that the prediction for a certain position is only dependent on known outputs
+          from positions before it.
+        * causal_padding / padding_causal: A combination of both causal and padding masks.
+          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.
+
+        .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.
     attn_bias_type: Optional[str], default = None
-        Type of the attention bias passed in the self attention.
+        Type of the attention bias passed in the attention.
         Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
         When default is present, the type is automatically decided by the MHA's bias parameter.
         Where it is :attr:`post_scale_bias` if there is bias. Otherwise :attr:`no_bias` is used.
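As the updated docstring states, 'padding' expects a boolean mask of shape [batch, 1, max_seqlen_q, max_seqlen_kv] with True marking padded positions. A minimal sketch of building one from per-sequence lengths (the helper name is illustrative, not part of this change):

    import jax.numpy as jnp

    def make_padding_mask(q_seqlens, kv_seqlens, s_q, s_kv):    # illustrative helper
        q_pad = jnp.arange(s_q)[None, :] >= q_seqlens[:, None]       # [b, s_q]
        kv_pad = jnp.arange(s_kv)[None, :] >= kv_seqlens[:, None]    # [b, s_kv]
        mask = q_pad[:, :, None] | kv_pad[:, None, :]                # [b, s_q, s_kv]
        return mask[:, None]                                         # [b, 1, s_q, s_kv]

    mask = make_padding_mask(jnp.array([3, 5]), jnp.array([3, 5]), 8, 8)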
@@ -438,6 +457,7 @@ class DotProductAttention(nn.Module):    # pylint: disable=too-few-public-methods
     mask: jax.numpy.ndarray, default = None
         Boolean tensor used to mask out the attention softmax input.
         :attr:`True` means to mask out the corresponding values.
+        Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
     bias: jax.numpy.ndarray, default = None
         A tensor used to shift attention softmax input.
     *:
@@ -639,9 +659,25 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
     attention_dropout: float, default = 0.0
         Dropout probability for the dropout op after the softmax.
     attn_mask_type: str, default = 'causal'
-        Type of the attention mask passed into softmax operation in the attention.
-        Available options: {'no_mask', 'padding', 'causal', 'causal_padding'}
-        Introduced in v0.10.0.
+        This parameter specifies the type of attention mask to be applied during the softmax
+        operation.
+        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}.
+        Each is described below:
+
+        * no_mask: No attention mask is applied. This means the attention will consider the
+          full sequence without any restrictions.
+        * padding: Indicates the presence of padding at the end of each sequence.
+          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
+          :attr:`__call__` method to specify the padding positions.
+        * causal: An upper triangular mask is applied to the softmax inputs,
+          ensuring that the prediction for a certain position is only dependent on known outputs
+          from positions before it.
+        * causal_padding / padding_causal: A combination of both causal and padding masks.
+          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.
+
+        .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.
     attn_bias_type: Optional[str], default = None
         Type of the attention bias passed in the attention.
         Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
@@ -809,6 +845,7 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
     mask: jax.numpy.ndarray, default = None
         Boolean tensor used to mask out the attention softmax input.
         :attr:`True` means mask out the corresponding values.
+        Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
     bias: jax.numpy.ndarray, default = None
         A tensor used to shift the attention softmax input.
     *
@@ -1299,9 +1336,25 @@ class TransformerLayer(nn.Module):    # pylint: disable=too-few-public-methods
         is added after self-attention. This can be used for structures like `T5`
         Transformer in conjunction with the TransformerLayerType.ENCODER option.
     self_attn_mask_type: str, default = 'causal'
-        Type of the attention mask passed into softmax operation in the self attention.
-        Available options: {'no_mask', 'padding', 'causal', 'causal_padding'}
-        Introduced in v0.10.0.
+        This parameter specifies the type of attention mask to be applied during the softmax
+        operation in the self attention.
+        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}.
+        Each is described below:
+
+        * no_mask: No attention mask is applied. This means the self attention will consider the
+          full sequence without any restrictions.
+        * padding: Indicates the presence of padding at the end of each sequence.
+          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
+          :attr:`__call__` method to specify the padding positions.
+        * causal: An upper triangular mask is applied to the softmax inputs,
+          ensuring that the prediction for a certain position is only dependent on known outputs
+          from positions before it.
+        * causal_padding / padding_causal: A combination of both causal and padding masks.
+          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.
+
+        .. note:: :attr:`attention_mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.
     self_attn_bias_type: Optional[str], default = None
         Type of the attention bias passed into the self attention.
         Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
@@ -1420,9 +1473,12 @@ class TransformerLayer(nn.Module):    # pylint: disable=too-few-public-methods
         :attr:`layer_type=TransformerLayerType.DECODER`.
     attention_mask : jax.numpy.ndarray, default = None
         Boolean tensor used to mask out self-attention softmax input.
+        :attr:`True` means mask out the corresponding values.
+        Ignored when :attr:`self.self_attn_mask_type` is either 'no_mask' or 'causal'.
     encoder_decoder_mask: jax.numpy.ndarray, default = None
         Boolean tensor used to mask out cross-attention softmax input when
         :attr:`layer_type=TransformerLayerType.DECODER`.
+        :attr:`True` means mask out the corresponding values.
     deterministic: bool, default = False
         Disable dropout layers if set to True.
     decode: bool, default = False
@@ -14,20 +14,29 @@ from transformer_engine_jax import NVTE_Mask_Type
 from transformer_engine_jax import NVTE_QKV_Layout
 from .cpp_extensions import FusedAttnHelper
-from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd
-from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd
+from .cpp_extensions import fused_attn_fwd_kvpacked, fused_attn_bwd_kvpacked
+from .cpp_extensions import fused_attn_fwd_qkvpacked, fused_attn_bwd_qkvpacked
 from .cpp_extensions import fused_attn_fwd, fused_attn_bwd

 class AttnBiasType(Enum):
-    """Attention Bias Type."""
+    """
+    NO_BIAS: Softmax is performed as softmax(scale * qk)
+    PRE_SCALE_BIAS: Softmax is performed as softmax(scale * (qk + bias))
+    POST_SCALE_BIAS: Softmax is performed as softmax(scale * qk + bias)
+    """
     NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS
     PRE_SCALE_BIAS = NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS
    POST_SCALE_BIAS = NVTE_Bias_Type.NVTE_POST_SCALE_BIAS
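A reference restatement of the three bias modes documented above (sketch only; `logits` stands for the q·kᵀ scores, and the string keys are illustrative):

    import jax

    def reference_softmax(logits, bias, scale, bias_type):    # illustrative only
        if bias_type == 'pre_scale_bias':
            return jax.nn.softmax(scale * (logits + bias), axis=-1)
        if bias_type == 'post_scale_bias':
            return jax.nn.softmax(scale * logits + bias, axis=-1)
        return jax.nn.softmax(scale * logits, axis=-1)        # no_bias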
 class AttnMaskType(Enum):
-    """Attention Mask Type."""
+    """
+    NO_MASK: No attention mask is applied.
+    PADDING_MASK: Indicates the presence of paddings at the end of each sequence.
+    CAUSAL_MASK: An upper triangular mask is applied to the softmax inputs.
+    PADDING_CAUSAL_MASK: A combination of both causal and padding masks.
+    """
     NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK
     PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK
     CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK
@@ -47,33 +56,38 @@ def canonicalize_attn_mask_type(attn_mask_type: str):
     The overhead between padding and non-padding version should be small.
     However, we will relax this limitation in the near future.
     """
-    if attn_mask_type in ['causal', 'padding_causal']:
-        return AttnMaskType.PADDING_CAUSAL_MASK
-    if attn_mask_type in ['no_mask', 'padding']:
-        return AttnMaskType.PADDING_MASK
-    raise ValueError(f"Unsupported {attn_mask_type=}, "
-                     "supported attn_mask_type={'no_mask', 'padding', 'causal', 'padding_causal'}")
+    match attn_mask_type:
+        case 'no_mask':
+            return AttnMaskType.NO_MASK
+        case 'padding':
+            return AttnMaskType.PADDING_MASK
+        case 'causal':
+            return AttnMaskType.CAUSAL_MASK
+        case 'padding_causal' | 'causal_padding':
+            return AttnMaskType.PADDING_CAUSAL_MASK
+    raise ValueError(f"Unsupported {attn_mask_type=}, supported attn_mask_type="
+                     "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}")
-def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, attn_mask_type,
-                                   dropout_probability, num_heads_q, num_heads_kv, max_seqlen_q,
-                                   max_seqlen_kv, head_dim):
+def is_fused_attn_kernel_available(q_dtype, kv_dtype, qkv_layout, attn_bias_type, attn_mask_type,
+                                   dropout_probability, q_num_heads, kv_num_heads, q_max_seqlen,
+                                   kv_max_seqlen, head_dim):
     """
-    To check whether the fused attention kernel is available
+    To check whether the fused attention kernel is supported
     """
-    return FusedAttnHelper(q_type, kv_type, qkv_layout.value, attn_bias_type.value,
-                           attn_mask_type.value, dropout_probability, num_heads_q, num_heads_kv,
-                           max_seqlen_q, max_seqlen_kv, head_dim).is_fused_attn_kernel_available()
+    return FusedAttnHelper(q_dtype, kv_dtype, qkv_layout.value, attn_bias_type.value,
+                           attn_mask_type.value, dropout_probability, q_num_heads, kv_num_heads,
+                           q_max_seqlen, kv_max_seqlen, head_dim).is_fused_attn_kernel_available()

-def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
-                    seed: jnp.ndarray | None, attn_bias_type: AttnBiasType,
-                    attn_mask_type: AttnMaskType, scaling_factor: float,
-                    dropout_probability: float, is_training: bool):
+def fused_attn_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
+                         seed: jnp.ndarray | None, attn_bias_type: AttnBiasType,
+                         attn_mask_type: AttnMaskType, scaling_factor: float,
+                         dropout_probability: float, is_training: bool):
     """
-    Self fused attention wrapper
+    Fused attention with the qkvpacked inputs
     """
-    output = _self_fused_attn(qkv,
+    output = _fused_attn_qkvpacked(qkv,
                               bias,
                               mask,
                               seed,
@@ -87,29 +101,30 @@ def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
 @partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
-def _self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
-                     seed: jnp.ndarray | None, attn_bias_type: AttnBiasType,
-                     attn_mask_type: AttnMaskType, scaling_factor: float,
-                     dropout_probability: float, is_training: bool):
-    output, _ = _self_fused_attn_fwd_rule(qkv, bias, mask, seed, attn_bias_type, attn_mask_type,
-                                          scaling_factor, dropout_probability, is_training)
+def _fused_attn_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
+                          seed: jnp.ndarray | None, attn_bias_type: AttnBiasType,
+                          attn_mask_type: AttnMaskType, scaling_factor: float,
+                          dropout_probability: float, is_training: bool):
+    output, _ = _fused_attn_fwd_qkvpacked_rule(qkv, bias, mask, seed, attn_bias_type,
+                                               attn_mask_type, scaling_factor, dropout_probability,
+                                               is_training)
     return output

-def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray | None,
-                              mask: jnp.ndarray, seed: jnp.ndarray | None,
-                              attn_bias_type: AttnBiasType,
-                              attn_mask_type: AttnMaskType,
-                              scaling_factor: float, dropout_probability: float,
-                              is_training: bool):
-    if mask is None:
+def _fused_attn_fwd_qkvpacked_rule(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
+                                   seed: jnp.ndarray | None, attn_bias_type: AttnBiasType,
+                                   attn_mask_type: AttnMaskType, scaling_factor: float,
+                                   dropout_probability: float, is_training: bool):
+    if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
         batch, seqlen, *_ = qkv.shape
         actual_seqlen = jnp.full((batch,), seqlen, dtype=jnp.int32)
     else:
+        assert mask is not None
         mask = jnp.logical_not(mask)
         actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
-    output, softmax_aux, rng_state = self_fused_attn_fwd(qkv,
+    output, softmax_aux, rng_state = fused_attn_fwd_qkvpacked(
+        qkv,
         bias,
         actual_seqlen,
         seed,
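How the padding branch above collapses a [b, 1, s, s] boolean mask into per-batch sequence lengths, on a toy example (values are illustrative):

    import jax.numpy as jnp

    pad = jnp.array([False, False, False, True])          # last position is padding
    mask = (pad[:, None] | pad[None, :])[None, None]      # (1, 1, 4, 4), True = masked out
    valid = jnp.logical_not(mask)
    actual_seqlen = jnp.sum(valid, axis=-2, dtype=jnp.int32)[..., 0, 0]
    # actual_seqlen == [3]: column 0 of the valid map counts the unpadded rows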
@@ -124,11 +139,11 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray | None,
     return output, (qkv, bias, softmax_aux, rng_state, output, actual_seqlen)

-def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
-                              is_training, ctx, dz):
+def _fused_attn_bwd_qkvpacked_rule(attn_bias_type, attn_mask_type, scaling_factor,
+                                   dropout_probability, is_training, ctx, dz):
     qkv, bias, softmax_aux, rng_state, output, actual_seqlen = ctx
-    grad_qkv, grad_bias = self_fused_attn_bwd(qkv,
+    grad_qkv, grad_bias = fused_attn_bwd_qkvpacked(qkv,
         bias,
         softmax_aux,
         rng_state,
@@ -147,17 +162,18 @@ def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
     return grad_qkv, grad_bias, None, None

-_self_fused_attn.defvjp(_self_fused_attn_fwd_rule, _self_fused_attn_bwd_rule)
+_fused_attn_qkvpacked.defvjp(_fused_attn_fwd_qkvpacked_rule, _fused_attn_bwd_qkvpacked_rule)
-def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
-                     seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
-                     scaling_factor: float, dropout_probability: float, is_training: bool):
+def fused_attn_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
+                        seed: jnp.ndarray, attn_bias_type: AttnBiasType,
+                        attn_mask_type: AttnMaskType, scaling_factor: float,
+                        dropout_probability: float, is_training: bool):
     """
-    Cross multi-head attention wrapper
+    Fused attention with the kvpacked inputs
     """
-    output = _cross_fused_attn(q,
+    output = _fused_attn_kvpacked(q,
         kv,
         bias,
         mask,
@@ -172,32 +188,36 @@ def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
 @partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9))
-def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
-                      seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
-                      scaling_factor: float, dropout_probability: float, is_training: bool):
-    output, _ = _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
-                                           scaling_factor, dropout_probability, is_training)
+def _fused_attn_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
+                         seed: jnp.ndarray, attn_bias_type: AttnBiasType,
+                         attn_mask_type: AttnMaskType, scaling_factor: float,
+                         dropout_probability: float, is_training: bool):
+    output, _ = _fused_attn_fwd_kvpacked_rule(q, kv, bias, mask, seed, attn_bias_type,
+                                              attn_mask_type, scaling_factor, dropout_probability,
+                                              is_training)
     return output

-def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
-                               scaling_factor, dropout_probability, is_training):
-    if mask is None:
+def _fused_attn_fwd_kvpacked_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
+                                  scaling_factor, dropout_probability, is_training):
+    if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
         batch, s_q, *_ = q.shape
         s_kv = kv.shape[1]
         q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
         kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
     else:
+        assert mask is not None
         mask = jnp.logical_not(mask)
         q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
-        if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
+        if attn_mask_type == AttnMaskType.PADDING_MASK:
             kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
         else:
             # When mask is causal, the actual seqlen is not the last row, use max to find it
             kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))
-    output, softmax_aux, rng_state = cross_fused_attn_fwd(q,
+    output, softmax_aux, rng_state = fused_attn_fwd_kvpacked(
+        q,
         kv,
         bias,
         q_actual_seqlen,
@@ -214,11 +234,11 @@ def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
     return output, (q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen)
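Why the causal branch above uses max() rather than a fixed row (sketch with illustrative values): under a causal mask each query row sees a different number of kv positions, so the true kv length is the largest per-row count, not the count at row 0.

    import jax.numpy as jnp

    causal = jnp.triu(jnp.ones((4, 4), dtype=bool), k=1)    # True = masked out
    valid = jnp.logical_not(causal)[None, None]             # (1, 1, 4, 4)
    row_counts = jnp.sum(valid, axis=-1, dtype=jnp.int32)   # per query row: 1, 2, 3, 4
    kv_actual_seqlen = jnp.max(row_counts, axis=(-1, -2))
    # kv_actual_seqlen == [4]; reading row 0 alone would wrongly give 1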
-def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
-                               is_training, ctx, dz):
+def _fused_attn_bwd_kvpacked_rule(attn_bias_type, attn_mask_type, scaling_factor,
+                                  dropout_probability, is_training, ctx, dz):
     q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx
-    grad_q, grad_kv, grad_bias = cross_fused_attn_bwd(q,
+    grad_q, grad_kv, grad_bias = fused_attn_bwd_kvpacked(q,
         kv,
         bias,
         softmax_aux,
@@ -239,7 +259,7 @@ def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
     return grad_q, grad_kv, grad_bias, None, None

-_cross_fused_attn.defvjp(_cross_fused_attn_fwd_rule, _cross_fused_attn_bwd_rule)
+_fused_attn_kvpacked.defvjp(_fused_attn_fwd_kvpacked_rule, _fused_attn_bwd_kvpacked_rule)
 def fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
@@ -277,15 +297,16 @@ def _fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray,
 def _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
                          dropout_probability, is_training):
-    if mask is None:
+    if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
         batch, s_q, *_ = q.shape
         s_kv = k.shape[1]
         q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
         kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
     else:
+        assert mask is not None
         mask = jnp.logical_not(mask)
         q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
-        if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
+        if attn_mask_type == AttnMaskType.PADDING_MASK:
             kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
         else:
             # When mask is causal, the actual seqlen is not the last row, use max to find it