Unverified Commit 8e672ff0 authored by Reese Wang, committed by GitHub

[JAX] Refactor fused attention (#711)



* Remove unused headers
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Unify the fused attn workspace size cpp code
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Reduce the skipped cases
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Rename self/cross attention to qkvpacked/kvpacked (see the before/after sketch below)
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Update attention mask docs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refine the attn mask implementations
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent b855656b
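Editor's note: as a quick orientation for the rename, here is a minimal before/after sketch. The argument order follows the signatures in this diff, and the packed shapes correspond to the BS3HD and BSHD_BS2HD layouts exercised in the tests; the call sites themselves are illustrative, not copied from the repository.

```python
from transformer_engine.jax.fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked

# Before this change:
#   out = self_fused_attn(qkv, bias, mask, seed, ...)     # qkv packed as [b, s, 3, h, d]
#   out = cross_fused_attn(q, kv, bias, mask, seed, ...)  # kv packed as  [b, s, 2, h, d]
#
# After this change, the same entry points are named for the input packing:
#   out = fused_attn_qkvpacked(qkv, bias, mask, seed,
#                              attn_bias_type=..., attn_mask_type=...,
#                              scaling_factor=1.0, dropout_probability=0.0, is_training=True)
#   out = fused_attn_kvpacked(q, kv, bias, mask, seed,
#                             attn_bias_type=..., attn_mask_type=...,
#                             scaling_factor=1.0, dropout_probability=0.0, is_training=True)
```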
......@@ -16,7 +16,7 @@ from distributed_test_base import compare_ops
from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn
from transformer_engine.jax.fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
DTYPES = [jnp.float16, jnp.bfloat16]
......@@ -86,7 +86,7 @@ class TestDistributedSelfAttn:
def target_func(qkv, bias, mask):
return jnp.mean(
self_fused_attn(qkv,
fused_attn_qkvpacked(qkv,
bias,
mask,
None,
......@@ -192,7 +192,7 @@ class TestDistributedCrossAttn:
def target_func(q, kv, mask):
return jnp.mean(
cross_fused_attn(q,
fused_attn_kvpacked(q,
kv,
None,
mask,
......
......@@ -2,8 +2,6 @@
#
# See LICENSE for license information.
"""Tests for fused attention"""
import sys
from enum import Enum
from dataclasses import dataclass
from functools import partial
......@@ -21,7 +19,7 @@ from jax import value_and_grad, jit
from jax.typing import ArrayLike, DTypeLike
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn, fused_attn
from transformer_engine.jax.fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked, fused_attn
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine_jax import NVTE_Fused_Attn_Backend
......@@ -144,11 +142,11 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng
case QKVLayout.BS3HD:
query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
qkv = jnp.concatenate((query, key, value), axis=-3)
return self_fused_attn(qkv, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
return fused_attn_qkvpacked(qkv, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
case QKVLayout.BSHD_BS2HD:
key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
kv = jnp.concatenate((key, value), axis=-3)
return cross_fused_attn(query, kv, bias, mask, dropout_rng,
return fused_attn_kvpacked(query, kv, bias, mask, dropout_rng,
**kwargs).astype(query.dtype)
case QKVLayout.BSHD_BSHD_BSHD:
return fused_attn(query, key, value, bias, mask, dropout_rng,
......@@ -156,6 +154,10 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng
class BiasShape(Enum):
"""
Enum class to represent the different bias shapes used in the fused attention.
"""
BIAS_1HSS = '1HSS'
BIAS_B1SS = 'B1SS'
BIAS_BHSS = 'BHSS'
......@@ -188,17 +190,16 @@ class FusedAttnRunner:
if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv:
pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.")
self.backend = FusedAttnHelper(
self.dtype, self.dtype, self.qkv_layout.value, self.attn_bias_type.value,
self.attn_mask_type.value, self.dropout_prob, self.num_heads_q, self.num_heads_kv,
self.max_seqlen_q, self.max_seqlen_kv, self.head_dim).get_fused_attn_backend()
self.backend = FusedAttnHelper(self.dtype, self.dtype, self.qkv_layout.value,
self.attn_bias_type.value, self.attn_mask_type.value,
self.dropout_prob, self.num_heads_q, self.num_heads_kv,
self.max_seqlen_q, self.max_seqlen_kv,
self.head_dim).get_fused_attn_backend()
if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
pytest.skip("Unsupported inputs combination or device compute capability.")
if self.bias_shape != BiasShape.BIAS_1HSS:
if self.attn_bias_type != AttnBiasType.POST_SCALE_BIAS:
pytest.skip("B1SS, BHSS and 11SS bias shapes require POST_SCALE_BIAS.")
elif self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for "
"AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK.")
elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
......@@ -213,7 +214,9 @@ class FusedAttnRunner:
q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
k_shape = v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim)
if self.bias_shape == BiasShape.BIAS_1HSS:
if self.attn_bias_type == AttnBiasType.NO_BIAS:
bias_shape = None
elif self.bias_shape == BiasShape.BIAS_1HSS:
bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
elif self.bias_shape == BiasShape.BIAS_B1SS:
bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
......@@ -222,7 +225,7 @@ class FusedAttnRunner:
elif self.bias_shape == BiasShape.BIAS_11SS:
bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
else:
pytest.xfail("PyTest attempted to use an unrecognized bias layout!")
pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")
self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.)
self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.)
......@@ -327,8 +330,8 @@ class FusedAttnRunner:
**kwargs), arg_nums))
jitted_reference = jit(
value_and_grad(
lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args,
**kwargs), arg_nums))
lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
arg_nums))
primitive_out, primitive_dgrad = jitted_primitive(*args)
reference_out, reference_dgrad = jitted_reference(*args)
......@@ -361,9 +364,9 @@ class FusedAttnRunner:
primitive_dbias = jnp.float32(primitive_dgrad[3])
reference_dbias = jnp.float32(reference_dgrad[3])
assert_allclose(
primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
jnp.zeros_like(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:]),
assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
jnp.zeros_like(primitive_dbias[..., self.valid_len_q:,
self.valid_len_kv:]),
dtype=self.dtype)
# dbias padded part
......@@ -376,15 +379,13 @@ class FusedAttnRunner:
reference_dbias[..., :self.valid_len_q, :self.valid_len_kv],
dtype=self.dtype)
@pytest.mark.parametrize('bias_shape', [
pytest.param(BiasShape.BIAS_1HSS, id='1-H-S-S'),
pytest.param(BiasShape.BIAS_B1SS, id='B-1-S-S'),
pytest.param(BiasShape.BIAS_BHSS, id='B-H-S-S'),
pytest.param(BiasShape.BIAS_11SS, id='1-1-S-S'),
])
@pytest.mark.parametrize('attn_bias_type', [
pytest.param(AttnBiasType.NO_BIAS, id='NO_BIAS'),
pytest.param(AttnBiasType.POST_SCALE_BIAS, id='POST_SCALE_BIAS'),
@pytest.mark.parametrize('attn_bias_type, bias_shape', [
pytest.param(AttnBiasType.NO_BIAS, None, id='NO_BIAS'),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id='POST_SCALE_BIAS-1HSS'),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id='POST_SCALE_BIAS-B1SS'),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id='POST_SCALE_BIAS-BHSS'),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id='POST_SCALE_BIAS-11SS'),
])
@pytest.mark.parametrize('attn_mask_type', [
pytest.param(AttnMaskType.NO_MASK, id='NO_MASK'),
......@@ -399,31 +400,32 @@ class FusedAttnRunner:
])
@pytest.mark.parametrize('dtype', [
pytest.param(jnp.bfloat16, id="BF16"),
pytest.param(jnp.float16, id="FP16")
pytest.param(jnp.float16, id="FP16"),
])
@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d',[
@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d', [
pytest.param(32, 128, 128, 16, 16, 64, id='32-128-128-16-16-64-SELF'),
pytest.param( 4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'),
pytest.param(4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'),
pytest.param(32, 512, 128, 16, 16, 64, id='32-512-128-16-16-64-CROSS'),
pytest.param( 4, 2048, 1024, 12, 12, 64, id='4-2048-1048-12-12-64-CROSS'),
pytest.param(4, 2048, 1024, 12, 12, 64, id='4-2048-1048-12-12-64-CROSS'),
pytest.param(32, 128, 128, 16, 8, 64, id='32-128-128-16-8-64-GQA'),
pytest.param( 4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA')
pytest.param(4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA'),
])
@pytest.mark.parametrize('dropout_prob', [
pytest.param(0.0, id="DROP_0.0"),
pytest.param(0.1, id="DROP_0.1")
])
@pytest.mark.parametrize('is_training', [
pytest.param(True, id='TRAINING'),
pytest.param(False, id='INFERENCE'),
pytest.param(0.1, id="DROP_0.1"),
])
class TestFusedAttn:
"""
Fused attention tester
"""
@staticmethod
def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
dropout_prob, dtype, is_training, qkv_layout, bias_shape):
@pytest.mark.parametrize('is_training', [
pytest.param(True, id='TRAINING'),
pytest.param(False, id='INFERENCE'),
])
def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
dtype, is_training, qkv_layout, bias_shape):
"""
Test forward with parameterized configs
"""
......@@ -432,13 +434,11 @@ class TestFusedAttn:
runner.test_forward()
@staticmethod
def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
dropout_prob, dtype, is_training, qkv_layout, bias_shape):
def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
dtype, qkv_layout, bias_shape):
"""
Test backward with parameterized configs
"""
if not is_training:
pytest.skip("Backward pass does not support inference.")
runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
dropout_prob, dtype, True, qkv_layout, bias_shape)
runner.test_backward()
......@@ -449,6 +449,7 @@ class TestDecoderLayer:
hidden_dropout_dims=(sequence_dim,),
intermediate_dropout_dims=(sequence_dim,),
layer_type=TransformerLayerType.DECODER,
self_attn_mask_type='padding_causal',
dtype=dtype,
**te_layer_attrs)
ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
......@@ -497,6 +498,7 @@ class TestDecoderLayer:
hidden_dropout_dims=(sequence_dim,),
intermediate_dropout_dims=(sequence_dim,),
layer_type=TransformerLayerType.DECODER,
self_attn_mask_type='padding_causal',
dtype=dtype,
**te_layer_attrs)
ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
......
......@@ -730,8 +730,13 @@ class TestDotProductAttn(TestLayer):
def input_getter(self, shape, dtype):
key = jax.random.PRNGKey(seed=1234)
q_key, k_key, v_key = jax.random.split(key, 3)
return list(map(partial(jax.random.normal, shape=shape, dtype=dtype),
[q_key, k_key, v_key]))
b, s, *_ = shape
if self.attrs[DotProductAttnAttr.TRANSPOSE_BS]:
b, s = s, b
mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
return [
*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]), mask
]
def get_layer_name(self):
return 'dot_product_attn'
......@@ -765,6 +770,7 @@ class TestDotProductAttn(TestLayer):
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', DotProductAttnAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
self.attrs = attrs
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
......@@ -853,9 +859,11 @@ class MultiHeadAttnAttr:
class TestMultiHeadAttn(TestLayer):
def input_getter(self, shape, dtype):
data_key = jax.random.PRNGKey(seed=1234)
return (jax.random.normal(data_key, shape,
dtype), jax.random.normal(data_key, shape, dtype))
key = jax.random.PRNGKey(seed=1234)
q_key, kv_key = jax.random.split(key, 2)
s, b, *_ = shape
mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
return [*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask]
def get_layer_name(self):
return 'multi_head_attn'
......@@ -1183,9 +1191,15 @@ class TransformerLayerAttr:
class TestTransformer(TestLayer):
def input_getter(self, shape, dtype):
data_key = jax.random.PRNGKey(seed=1234)
return (jax.random.normal(data_key, shape,
dtype), jax.random.normal(data_key, shape, dtype))
key = jax.random.PRNGKey(seed=1234)
q_key, kv_key = jax.random.split(key, 2)
b, s, *_ = shape
if self.attrs[TransformerLayerAttr.TRANSPOSE_BS]:
b, s = s, b
mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
return [
*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask, mask
]
def get_layer_name(self):
return 'transformerlayer'
......@@ -1277,6 +1291,7 @@ class TestTransformer(TestLayer):
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', TransformerLayerAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
self.attrs = attrs
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
......@@ -1292,7 +1307,7 @@ class TestTransformer(TestLayer):
fp8_format,
rtol=1e-05,
atol=1e-08):
self.attrs = attrs
ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
......
......@@ -49,10 +49,6 @@ pybind11::dict Registrations() {
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForward);
dict["te_scaled_upper_triang_masked_softmax_backward"] =
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward);
dict["te_self_fused_attn_forward"] = EncapsulateFunction(SelfFusedAttnForward);
dict["te_self_fused_attn_backward"] = EncapsulateFunction(SelfFusedAttnBackward);
dict["te_cross_fused_attn_forward"] = EncapsulateFunction(CrossFusedAttnForward);
dict["te_cross_fused_attn_backward"] = EncapsulateFunction(CrossFusedAttnBackward);
dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward);
dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward);
return dict;
......@@ -72,10 +68,6 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_dgelu_dbias_ct_workspace_sizes", &GetDGeluDBiasCastTransposeWorkspaceSizes);
m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes);
m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes);
m.def("get_self_fused_attn_fwd_workspace_sizes", &GetSelfFusedAttnForwardWorkspaceSizes);
m.def("get_self_fused_attn_bwd_workspace_sizes", &GetSelfFusedAttnBackwardWorkspaceSizes);
m.def("get_cross_fused_attn_fwd_workspace_sizes", &GetCrossFusedAttnForwardWorkspaceSizes);
m.def("get_cross_fused_attn_bwd_workspace_sizes", &GetCrossFusedAttnBackwardWorkspaceSizes);
m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes);
m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
......
......@@ -118,17 +118,18 @@ struct CustomCallFusedAttnDescriptor {
float dropout_probability;
NVTE_Bias_Type bias_type;
NVTE_Mask_Type mask_type;
NVTE_QKV_Layout qkv_layout;
DType dtype;
DType wkspace_dtype;
bool is_training;
};
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t input_batch, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
size_t wkspace_size, float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
DType dtype, DType wkspace_dtype, bool is_training);
size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
bool is_training);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
......@@ -207,55 +208,19 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers,
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t max_seqlen,
size_t attn_heads, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t max_seqlen,
size_t attn_heads, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training);
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training);
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
......
......@@ -26,7 +26,7 @@ from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
from ..fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from ..fused_attn import is_fused_attn_kernel_available, canonicalize_attn_mask_type
from ..fused_attn import self_fused_attn, cross_fused_attn, fused_attn
from ..fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked, fused_attn
from ..softmax import SoftmaxType
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
......@@ -190,16 +190,19 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-publi
def convert_to_softmax_type(attn_mask_type, mask):
"""Convert the attn_mask_type to SoftmaxType"""
# mask is ignored for no_mask and causal_mask
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
mask = None
if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
return SoftmaxType.SCALED_UPPER_TRIANG_MASKED
return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]:
if mask is not None:
return SoftmaxType.SCALED_MASKED
return SoftmaxType.SCALED
raise ValueError(f"Unsupported {attn_mask_type=}, "
"supported attn_mask_type = {'causal', 'padding'}")
return SoftmaxType.SCALED_MASKED, mask
return SoftmaxType.SCALED, mask
raise ValueError(f"Unsupported {attn_mask_type=}, supported attn_mask_type="
"{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}")
softmax_type = convert_to_softmax_type(self.attn_mask_type, mask)
softmax_type, mask = convert_to_softmax_type(self.attn_mask_type, mask)
attn_weights = Softmax(softmax_type=softmax_type,
scale_factor=fused_scale_factor)(attn_weights, mask,
......@@ -266,7 +269,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
qkv_packed = query
if self.transpose_batch_sequence:
qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4])
x = self_fused_attn(qkv_packed,
x = fused_attn_qkvpacked(qkv_packed,
bias,
mask,
seed,
......@@ -285,7 +288,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
if self.transpose_batch_sequence:
query = query.transpose([1, 0, 2, 3])
kv_packed = kv_packed.transpose([1, 0, 2, 3, 4])
x = cross_fused_attn(query,
x = fused_attn_kvpacked(query,
kv_packed,
bias,
mask,
......@@ -358,11 +361,27 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-method
attention_dropout: float, default = 0.0
Dropout probability for the dropout op after the softmax.
attn_mask_type: str, default = 'causal'
Type of the attention mask passed into softmax operation in the self attention.
Available options: {'no_mask', 'padding', 'causal', 'causal_padding'}
Introduced in v0.10.0.
This parameter specifies the type of attention mask to be applied during the softmax
operation.
Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}
Each described below:
* no_mask: No attention mask is applied. This means the attention will consider the
full sequence without any restrictions.
* padding: Indicates the presence of padding at the end of each sequence.
Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
:attr:`__call__` method to specify the padding positions.
* causal: An upper triangular mask is applied to the softmax inputs,
ensuring that the prediction for a certain position is only dependent on known outputs
from positions before it.
* causal_padding / padding_causal: A combination of both causal and padding masks.
Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.
.. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.
attn_bias_type: Optional[str], default = None
Type of the attention bias passed in the self attention.
Type of the attention bias passed in the attention.
Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
When default is present, the type is automatically decided by the MHA's bias parameter.
Where it is :attr:`post_scale_bias` if there is bias. Otherwise :attr:`no_bias` is used.
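Editor's note: the docstring above documents a padding mask of shape [batch, 1, max_seqlen_q, max_seqlen_kv] where True marks positions to mask out. A small illustrative helper for building such a mask from per-sequence valid lengths is sketched below; the helper name and the right-padding assumption are mine, not part of the library.

```python
import jax.numpy as jnp

def make_padding_mask(q_lens, kv_lens, max_seqlen_q, max_seqlen_kv):
    """Build a [batch, 1, max_seqlen_q, max_seqlen_kv] mask; True = mask out (padded)."""
    q_pad = jnp.arange(max_seqlen_q)[None, :] >= q_lens[:, None]     # [b, s_q]
    kv_pad = jnp.arange(max_seqlen_kv)[None, :] >= kv_lens[:, None]  # [b, s_kv]
    mask = q_pad[:, :, None] | kv_pad[:, None, :]                    # [b, s_q, s_kv]
    return mask[:, None, :, :]                                       # [b, 1, s_q, s_kv]

# e.g. two sequences with valid lengths 3 and 5, padded to 5:
# mask = make_padding_mask(jnp.array([3, 5]), jnp.array([3, 5]), 5, 5)
```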
......@@ -438,6 +457,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-method
mask: jax.numpy.ndarray, default = None
Boolean tensor used to mask out the attention softmax input.
:attr:`True` means to mask out the corresponding values.
Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
bias: jax.numpy.ndarray, default = None
A tensor used to shift attention softmax input.
*:
......@@ -639,9 +659,25 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
attention_dropout: float, default = 0.0
Dropout probability for the dropout op after the softmax.
attn_mask_type: str, default = 'causal'
Type of the attention mask passed into softmax operation in the attention.
Available options: {'no_mask', 'padding', 'causal', 'causal_padding'}
Introduced in v0.10.0.
This parameter specifies the type of attention mask to be applied during the softmax
operation.
Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}
Each described below:
* no_mask: No attention mask is applied. This means the attention will consider the
full sequence without any restrictions.
* padding: Indicates the presence of padding at the end of each sequence.
Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
:attr:`__call__` method to specify the padding positions.
* causal: An upper triangular mask is applied to the softmax inputs,
ensuring that the prediction for a certain position is only dependent on known outputs
from positions before it.
* causal_padding / padding_causal: A combination of both causal and padding masks.
Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.
.. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.
attn_bias_type: Optional[str], default = None
Type of the attention bias passed in the attention.
Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
......@@ -809,6 +845,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
mask: jax.numpy.ndarray, default = None
Boolean tensor used to mask out the attention softmax input.
:attr:`True` means mask out the corresponding values.
Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
bias: jax.numpy.ndarray, default = None
A tensor used to shift the attention softmax input.
*
......@@ -1299,9 +1336,25 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
is added after self-attention. This can be used for structures like `T5`
Transformer in conjunction with the TransformerLayerType.ENCODER option.
self_attn_mask_type: str, default = 'causal'
Type of the attention mask passed into softmax operation in the self attention.
Available options: {'no_mask', 'padding', 'causal', 'causal_padding'}
Introduced in v0.10.0.
This parameter specifies the type of attention mask to be applied during the softmax
operation in the self attention.
Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}
Each described below:
* no_mask: No attention mask is applied. This means the self attention will consider the
full sequence without any restrictions.
* padding: Indicates the presence of padding at the end of each sequence.
Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
:attr:`__call__` method to specify the padding positions.
* causal: An upper triangular mask is applied to the softmax inputs,
ensuring that the prediction for a certain position is only dependent on known outputs
from positions before it.
* causal_padding / padding_causal: A combination of both causal and padding masks.
Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.
.. note:: :attr:`attention_mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.
self_attn_bias_type: Optional[str], default = None
Type of the attention bias passed into the self attention.
Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
......@@ -1420,9 +1473,12 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
:attr:`layer_type=TransformerLayerType.DECODER`.
attention_mask : jax.numpy.ndarray, default = None
Boolean tensor used to mask out self-attention softmax input.
:attr:`True` means mask out the corresponding values.
Ignored when :attr:`self.self_attn_mask_type` is either 'no_mask' or 'causal'.
encoder_decoder_mask: jax.numpy.ndarray, default = None
Boolean tensor used to mask out cross-attention softmax input when
:attr:`layer_type=TransformerLayerType.DECODER`.
:attr:`True` means mask out the corresponding values.
deterministic: bool, default = False
Disable dropout layers if set to True.
decode: bool, default = False
......
......@@ -14,20 +14,29 @@ from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout
from .cpp_extensions import FusedAttnHelper
from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd
from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd
from .cpp_extensions import fused_attn_fwd_kvpacked, fused_attn_bwd_kvpacked
from .cpp_extensions import fused_attn_fwd_qkvpacked, fused_attn_bwd_qkvpacked
from .cpp_extensions import fused_attn_fwd, fused_attn_bwd
class AttnBiasType(Enum):
"""Attention Bias Type."""
"""
NO_BIAS: Softmax is performed as softmax(scale * qk)
PRE_SCALE_BIAS: Softmax is performed as softmax(scale * (qk + bias))
POST_SCALE_BIAS: Softmax is performed as softmax(scale * qk + bias)
"""
NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS
PRE_SCALE_BIAS = NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS
POST_SCALE_BIAS = NVTE_Bias_Type.NVTE_POST_SCALE_BIAS
class AttnMaskType(Enum):
"""Attention Mask Type."""
"""
NO_MASK: No attention mask is applied.
PADDING_MASK: Indicates the presence of paddings at the end of each sequence.
CAUSAL_MASK: An upper triangular mask is applied to the softmax inputs.
PADDING_CAUSAL_MASK: A combination of both causal and padding masks.
"""
NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK
PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK
CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK
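Editor's note: for reference, the three AttnBiasType formulas documented above correspond to the following unfused computation. This is an illustrative reference only (string labels stand in for the enum values, and shapes are assumed), not how the fused kernel is implemented.

```python
import jax
import jax.numpy as jnp

def reference_attn_softmax(q, k, bias, scale, bias_type):
    """Unfused reference: q/k are [b, s, h, d], bias broadcasts to [b, h, s_q, s_kv]."""
    logits = jnp.einsum('bqhd,bkhd->bhqk', q, k)  # raw qk scores
    if bias_type == 'no_bias':
        x = scale * logits                 # softmax(scale * qk)
    elif bias_type == 'pre_scale_bias':
        x = scale * (logits + bias)        # softmax(scale * (qk + bias))
    else:                                  # 'post_scale_bias'
        x = scale * logits + bias          # softmax(scale * qk + bias)
    return jax.nn.softmax(x, axis=-1)
```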
......@@ -47,33 +56,38 @@ def canonicalize_attn_mask_type(attn_mask_type: str):
The overhead between the padding and non-padding versions should be small.
However, we will lift this limitation in the near future.
"""
if attn_mask_type in ['causal', 'padding_causal']:
return AttnMaskType.PADDING_CAUSAL_MASK
if attn_mask_type in ['no_mask', 'padding']:
match attn_mask_type:
case 'no_mask':
return AttnMaskType.NO_MASK
case 'padding':
return AttnMaskType.PADDING_MASK
raise ValueError(f"Unsupported {attn_mask_type=}, "
"supported attn_mask_type={'no_mask', 'padding', 'causal', 'padding_causal'}")
case 'causal':
return AttnMaskType.CAUSAL_MASK
case 'padding_causal' | 'causal_padding':
return AttnMaskType.PADDING_CAUSAL_MASK
raise ValueError(f"Unsupported {attn_mask_type=}, supported attn_mask_type="
"{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}")
def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, attn_mask_type,
dropout_probability, num_heads_q, num_heads_kv, max_seqlen_q,
max_seqlen_kv, head_dim):
def is_fused_attn_kernel_available(q_dtype, kv_dtype, qkv_layout, attn_bias_type, attn_mask_type,
dropout_probability, q_num_heads, kv_num_heads, q_max_seqlen,
kv_max_seqlen, head_dim):
"""
To check whether the fused attention kernel is available
To check whether the fused attention kernel is supported
"""
return FusedAttnHelper(q_type, kv_type, qkv_layout.value, attn_bias_type.value,
attn_mask_type.value, dropout_probability, num_heads_q, num_heads_kv,
max_seqlen_q, max_seqlen_kv, head_dim).is_fused_attn_kernel_available()
return FusedAttnHelper(q_dtype, kv_dtype, qkv_layout.value, attn_bias_type.value,
attn_mask_type.value, dropout_probability, q_num_heads, kv_num_heads,
q_max_seqlen, kv_max_seqlen, head_dim).is_fused_attn_kernel_available()
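Editor's note: an illustrative call to the helper above with its renamed parameters; all concrete values (dtype, layout, heads, lengths) are made up for the example.

```python
import jax.numpy as jnp
from transformer_engine.jax.fused_attn import (AttnBiasType, AttnMaskType, QKVLayout,
                                               is_fused_attn_kernel_available)

# q_dtype, kv_dtype, qkv_layout, attn_bias_type, attn_mask_type, dropout_probability,
# q_num_heads, kv_num_heads, q_max_seqlen, kv_max_seqlen, head_dim
ok = is_fused_attn_kernel_available(jnp.float16, jnp.float16, QKVLayout.BS3HD,
                                    AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK,
                                    0.0, 16, 16, 128, 128, 64)
```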
def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
def fused_attn_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
seed: jnp.ndarray | None, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
"""
Self fused attention wrapper
Fused attention with the qkvpacked inputs
"""
output = _self_fused_attn(qkv,
output = _fused_attn_qkvpacked(qkv,
bias,
mask,
seed,
......@@ -87,29 +101,30 @@ def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarra
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
def _fused_attn_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
seed: jnp.ndarray | None, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
output, _ = _self_fused_attn_fwd_rule(qkv, bias, mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training)
output, _ = _fused_attn_fwd_qkvpacked_rule(qkv, bias, mask, seed, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability,
is_training)
return output
def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray | None,
mask: jnp.ndarray, seed: jnp.ndarray | None,
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
scaling_factor: float, dropout_probability: float,
is_training: bool):
if mask is None:
def _fused_attn_fwd_qkvpacked_rule(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
seed: jnp.ndarray | None, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
batch, seqlen, *_ = qkv.shape
actual_seqlen = jnp.full((batch,), seqlen, dtype=jnp.int32)
else:
assert mask is not None
mask = jnp.logical_not(mask)
actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
output, softmax_aux, rng_state = self_fused_attn_fwd(qkv,
output, softmax_aux, rng_state = fused_attn_fwd_qkvpacked(
qkv,
bias,
actual_seqlen,
seed,
......@@ -124,11 +139,11 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray | None,
return output, (qkv, bias, softmax_aux, rng_state, output, actual_seqlen)
def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, dz):
def _fused_attn_bwd_qkvpacked_rule(attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training, ctx, dz):
qkv, bias, softmax_aux, rng_state, output, actual_seqlen = ctx
grad_qkv, grad_bias = self_fused_attn_bwd(qkv,
grad_qkv, grad_bias = fused_attn_bwd_qkvpacked(qkv,
bias,
softmax_aux,
rng_state,
......@@ -147,17 +162,18 @@ def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dr
return grad_qkv, grad_bias, None, None
_self_fused_attn.defvjp(_self_fused_attn_fwd_rule, _self_fused_attn_bwd_rule)
_fused_attn_qkvpacked.defvjp(_fused_attn_fwd_qkvpacked_rule, _fused_attn_bwd_qkvpacked_rule)
def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
scaling_factor: float, dropout_probability: float, is_training: bool):
def fused_attn_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
seed: jnp.ndarray, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
"""
Cross multi-head attention wrapper
Fused attention with the kvpacked inputs
"""
output = _cross_fused_attn(q,
output = _fused_attn_kvpacked(q,
kv,
bias,
mask,
......@@ -172,32 +188,36 @@ def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: j
@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9))
def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
scaling_factor: float, dropout_probability: float, is_training: bool):
def _fused_attn_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
seed: jnp.ndarray, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
output, _ = _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training)
output, _ = _fused_attn_fwd_kvpacked_rule(q, kv, bias, mask, seed, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability,
is_training)
return output
def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
def _fused_attn_fwd_kvpacked_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
if mask is None:
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
batch, s_q, *_ = q.shape
s_kv = kv.shape[1]
q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
else:
assert mask is not None
mask = jnp.logical_not(mask)
q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
if attn_mask_type == AttnMaskType.PADDING_MASK:
kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
else:
# When mask is causal, the actual seqlen is not the last row, use max to find it
kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))
output, softmax_aux, rng_state = cross_fused_attn_fwd(q,
output, softmax_aux, rng_state = fused_attn_fwd_kvpacked(
q,
kv,
bias,
q_actual_seqlen,
......@@ -214,11 +234,11 @@ def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mas
return output, (q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen)
def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, dz):
def _fused_attn_bwd_kvpacked_rule(attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training, ctx, dz):
q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx
grad_q, grad_kv, grad_bias = cross_fused_attn_bwd(q,
grad_q, grad_kv, grad_bias = fused_attn_bwd_kvpacked(q,
kv,
bias,
softmax_aux,
......@@ -239,7 +259,7 @@ def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, d
return grad_q, grad_kv, grad_bias, None, None
_cross_fused_attn.defvjp(_cross_fused_attn_fwd_rule, _cross_fused_attn_bwd_rule)
_fused_attn_kvpacked.defvjp(_fused_attn_fwd_kvpacked_rule, _fused_attn_bwd_kvpacked_rule)
def fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
......@@ -277,15 +297,16 @@ def _fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarra
def _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
if mask is None:
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
batch, s_q, *_ = q.shape
s_kv = k.shape[1]
q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
else:
assert mask is not None
mask = jnp.logical_not(mask)
q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
if attn_mask_type == AttnMaskType.PADDING_MASK:
kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
else:
# When mask is causal, the actual seqlen is not the last row, use max to find it
......
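Editor's note: the forward rules above recover per-batch sequence lengths from the boolean mask rather than taking them as arguments. A standalone sketch of that logic is below, assuming a [batch, 1, s_q, s_kv] mask where True marks masked-out positions; the helper name is mine.

```python
import jax.numpy as jnp

def actual_seqlens_from_mask(mask, causal):
    """Recover per-batch valid q/kv lengths from a [b, 1, s_q, s_kv] padding mask."""
    keep = jnp.logical_not(mask)                                 # True on valid positions
    q_len = jnp.sum(keep, axis=-2, dtype=jnp.int32)[..., 0, 0]   # valid q rows, shape (b,)
    if causal:
        # Each row of a causal mask keeps a different number of kv positions,
        # so the kv length is the largest row sum.
        kv_len = jnp.max(jnp.sum(keep, axis=-1, dtype=jnp.int32), axis=(-1, -2))
    else:
        kv_len = jnp.sum(keep, axis=-1, dtype=jnp.int32)[..., 0, 0]
    return q_len, kv_len
```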