Unverified Commit 9b2fed51 authored by Reese Wang, committed by GitHub

[JAX] Refine MHA API and add DPA API (#653)



* Refine MHA API
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Reuse func from the flax
Signed-off-by: Reese Wang <rewang@nvidia.com>

* DPA draft
Signed-off-by: Reese Wang <rewang@nvidia.com>

* QKV packed draft
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix test_layer with fused attn
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add attn_bias_type and enhance a few code flows
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Move scale_factor from __call__ to init
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Enhance the docs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add DPA public API and tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refine docs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refine docs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix conflict
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add qkv separate fused attn
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Apply BSHD_BSHD_BSHD format
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove debug log
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add fused attention layer tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add NVTE_FUSED_ATTN docs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fine-grained fused attn settings
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove the default values of num_attention_heads and head_dim
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add teardown for fused attn env
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Unify the Optional notation
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix Pre/Post scale bias comments
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add no_mask tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add checkpoint_name for fused attn
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix the fused attn batcher
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent fb2f952a
......@@ -45,6 +45,9 @@ Modules
.. autoapiclass:: transformer_engine.jax.flax.RelativePositionBiases(num_buckets, max_distance, num_heads, **kwargs)
:members: __call__
.. autoapiclass:: transformer_engine.jax.flax.DotProductAttention(head_dim, num_attention_heads, **kwargs)
:members: __call__
.. autoapiclass:: transformer_engine.jax.flax.MultiHeadAttention(head_dim, num_attention_heads, **kwargs)
:members: __call__
......
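As a quick orientation for the two new entries, a minimal usage sketch of the flax DotProductAttention (shapes follow the tests below; that `deterministic` is forwarded through `apply` exactly as in the inner modules is an assumption):

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import DotProductAttention

# [batch, seqlen, heads, head_dim], matching the test shapes below
q = jnp.ones((32, 128, 16, 64), dtype=jnp.bfloat16)
k = jnp.ones((32, 128, 16, 64), dtype=jnp.bfloat16)
v = jnp.ones((32, 128, 16, 64), dtype=jnp.bfloat16)

dpa = DotProductAttention(head_dim=64,
                          num_attention_heads=16,
                          attn_mask_type='causal',
                          transpose_batch_sequence=False)
variables = dpa.init(jax.random.PRNGKey(0), q, k, v)
out = dpa.apply(variables, q, k, v, deterministic=True)    # no dropout rng needed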
......@@ -20,7 +20,7 @@ from jax import value_and_grad, jit
from jax.typing import ArrayLike, DTypeLike
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn, fused_attn
from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
......@@ -144,6 +144,9 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng
kv = jnp.concatenate((key, value), axis=-3)
return cross_fused_attn(query, kv, bias, mask, dropout_rng,
**kwargs).astype(query.dtype)
case QKVLayout.BSHD_BSHD_BSHD:
return fused_attn(query, key, value, bias, mask, dropout_rng,
**kwargs).astype(query.dtype)
@dataclass
......@@ -337,6 +340,7 @@ class FusedAttnRunner:
@pytest.mark.parametrize('qkv_layout', [
pytest.param(QKVLayout.BS3HD, id='qkvpacked'),
pytest.param(QKVLayout.BSHD_BS2HD, id='kvpacked'),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id='separate'),
])
@pytest.mark.parametrize('dropout_prob', [0., 0.1])
@pytest.mark.parametrize('is_training',
......
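For reference, how the three parametrized layouts pack q/k/v (a sketch; shapes assumed to be [batch, seqlen, heads, head_dim], and the packing axis is an assumption based on customcall_fused_dpa above):

import jax.numpy as jnp

b, s, h, d = 2, 128, 16, 64
q = k = v = jnp.ones((b, s, h, d))

qkv = jnp.stack((q, k, v), axis=-3)    # BS3HD ('qkvpacked'): [b, s, 3, h, d]
kv = jnp.stack((k, v), axis=-3)        # BSHD_BS2HD ('kvpacked'): [b, s, 2, h, d]
# BSHD_BSHD_BSHD ('separate'): q, k, v stay as three [b, s, h, d] tensors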
......@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
import os
from functools import partial
import flax
......@@ -20,6 +21,16 @@ from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
is_fp8_supported, reason = is_fp8_available()
@pytest.fixture(autouse=True, scope='module')
def enable_fused_attn():
"""
Enable fused attention
"""
os.environ["NVTE_FUSED_ATTN"] = "1"
yield
del os.environ["NVTE_FUSED_ATTN"]
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
......@@ -93,6 +104,7 @@ _KEY_OF_ENABLE_ROPE = "enable_rotary_pos_emb"
BASE_ATTRS = {
_KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_NUM_HEADS: 8,
_KEY_OF_DROPOUT_RATE: 0,
}
ATTRS = [{
......@@ -221,6 +233,7 @@ class TestEncoderLayer:
ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng)
test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng)
if attrs[_KEY_OF_DROPOUT_RATE] == 0.:    # Skip elementwise checking when dropout is enabled
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
del data_rng, init_rng, apply_rng
......@@ -282,9 +295,6 @@ class TestEncoderLayer:
test_out, test_grads = grad_fn(inputs, test_masks, test_params, test_others, test_layer,
apply_rng)
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad
def reorganize_test_wgrad(test_wgrad, attrs):
num_heads = attrs.get(_KEY_OF_NUM_HEADS)
num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads)
......@@ -328,6 +338,10 @@ class TestEncoderLayer:
del unfreeze_test_wgrad['mlp']['wo_kernel']
return unfreeze_test_wgrad
if attrs[_KEY_OF_DROPOUT_RATE] == 0.:    # Skip elementwise checking when dropout is enabled
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad
compare_dict(ref_grads[1],
reorganize_test_wgrad(test_grads[1], attrs),
rtol=rtol,
......@@ -430,6 +444,7 @@ class TestDecoderLayer:
ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng)
test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng)
if attrs[_KEY_OF_DROPOUT_RATE] == 0.:    # Skip elementwise checking when dropout is enabled
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
del data_rng, init_rng, apply_rng
......@@ -492,9 +507,6 @@ class TestDecoderLayer:
test_out, test_grads = grad_fn(inputs, test_masks, test_params, test_others, test_layer,
apply_rng)
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad
def reorganize_test_wgrad(test_wgrad, attrs):
num_heads = attrs.get(_KEY_OF_NUM_HEADS)
num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads)
......@@ -547,6 +559,9 @@ class TestDecoderLayer:
del unfreeze_test_wgrad['mlp']['wo_kernel']
return unfreeze_test_wgrad
if attrs[_KEY_OF_DROPOUT_RATE] == 0.:    # Skip elementwise checking when dropout is enabled
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad
compare_dict(ref_grads[1],
reorganize_test_wgrad(test_grads[1], attrs),
rtol=rtol,
......
......@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
import os
from functools import partial
from typing import Dict
......@@ -14,12 +15,14 @@ import pytest
from utils import assert_allclose
from transformer_engine_jax import get_device_compute_capability
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast, update_fp8_metas, update_collections
from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral
from transformer_engine.jax.flax import LayerNorm as flax_LayerNorm
from transformer_engine.jax.flax import LayerNormMLP as flax_LayerNormMLP
from transformer_engine.jax.flax import MultiHeadAttention as flax_MultiHeadAttention
from transformer_engine.jax.flax import DotProductAttention as flax_DotProductAttention
from transformer_engine.jax.flax import RelativePositionBiases as flax_RelativePositionBiases
from transformer_engine.jax.flax import TransformerLayer as flax_TransformerLayer
from transformer_engine.jax.flax.module import Softmax
......@@ -27,8 +30,8 @@ from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
from transformer_engine.jax.praxis import LayerNorm
from transformer_engine.jax.praxis import FusedSoftmax
from transformer_engine.jax.praxis import LayerNormLinear, LayerNormMLP, Linear
from transformer_engine.jax.praxis import MultiHeadAttention, RelativePositionBiases
from transformer_engine.jax.praxis import TransformerEngineBaseLayer
from transformer_engine.jax.praxis import DotProductAttention, MultiHeadAttention
from transformer_engine.jax.praxis import RelativePositionBiases, TransformerEngineBaseLayer
from transformer_engine.jax.praxis import TransformerLayer, TransformerLayerType
from transformer_engine.jax.softmax import SoftmaxType
......@@ -40,6 +43,19 @@ ENABLE_FP8 = [False, True]
FP8_FORMATS = [Format.E4M3, Format.HYBRID]
@pytest.fixture(autouse=True, scope='module')
def enable_fused_attn():
"""
Enable fused attn for Hopper+ archs.
Fused attn kernels on pre-Hopper archs are not deterministic.
"""
if get_device_compute_capability(0) >= 90:
os.environ["NVTE_FUSED_ATTN"] = "1"
yield
if "NVTE_FUSED_ATTN" in os.environ:
del os.environ["NVTE_FUSED_ATTN"]
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
......@@ -101,6 +117,7 @@ class TestLayer:
lyr_name = self.get_layer_name()
if 'params' in flax_variables:
synced_praxis_variables['params'][lyr_name]['cld'] = \
flax.core.unfreeze(flax_variables['params'])
......@@ -111,6 +128,7 @@ class TestLayer:
lyr_name = self.get_layer_name()
if 'params' in synced_praxis_grads:
synced_praxis_grads['params'] = \
synced_praxis_grads['params'][lyr_name]['cld']
......@@ -671,6 +689,86 @@ class TestRelativePositionBias(TestLayer):
assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol)
class DotProductAttnAttr:
ATTN_MASK_TYPE = 'attn_mask_type'
NUM_GQA_GROUPS = 'num_gqa_groups'
TRANSPOSE_BS = 'transpose_batch_sequence'
SCALE_FACTOR = 'scale_factor'
ATTRS = [{
ATTN_MASK_TYPE: 'padding',
TRANSPOSE_BS: True,
SCALE_FACTOR: 0.125,
}, {
ATTN_MASK_TYPE: 'padding_causal',
TRANSPOSE_BS: True,
SCALE_FACTOR: 0.125,
}, {
ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: True,
SCALE_FACTOR: 0.125,
}, {
ATTN_MASK_TYPE: 'padding',
TRANSPOSE_BS: False,
SCALE_FACTOR: 0.125,
}, {
ATTN_MASK_TYPE: 'padding_causal',
TRANSPOSE_BS: False,
SCALE_FACTOR: 2.,
}, {
ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: False,
SCALE_FACTOR: 1.,
}, {
ATTN_MASK_TYPE: 'no_mask',
TRANSPOSE_BS: False,
SCALE_FACTOR: 1.,
}]
class TestDotProductAttn(TestLayer):
def input_getter(self, shape, dtype):
key = jax.random.PRNGKey(seed=1234)
q_key, k_key, v_key = jax.random.split(key, 3)
return list(map(partial(jax.random.normal, shape=shape, dtype=dtype),
[q_key, k_key, v_key]))
def get_layer_name(self):
return 'dot_product_attn'
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
head_dim = 64
num_attention_heads = 16
num_gqa_groups = num_attention_heads
attn_mask_type = attrs[DotProductAttnAttr.ATTN_MASK_TYPE]
transpose_batch_sequence = attrs[DotProductAttnAttr.TRANSPOSE_BS]
praxis_p = pax_fiddle.Config(DotProductAttention,
name='mha',
dtype=dtype,
head_dim=head_dim,
num_attention_heads=num_attention_heads,
num_gqa_groups=num_gqa_groups,
attn_mask_type=attn_mask_type,
transpose_batch_sequence=transpose_batch_sequence)
flax_cls = partial(flax_DotProductAttention,
dtype=dtype,
head_dim=head_dim,
num_attention_heads=num_attention_heads,
num_gqa_groups=num_gqa_groups,
attn_mask_type=attn_mask_type,
transpose_batch_sequence=transpose_batch_sequence)
return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', [(32, 128, 16, 64)])
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', DotProductAttnAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class MultiHeadAttnAttr:
USE_BIAS = 'use_bias'
LN_TYPE = 'layernorm_type'
......@@ -730,36 +828,38 @@ class TestMultiHeadAttn(TestLayer):
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
head_dim = 64
num_heads = 16
num_attention_heads = 16
num_gqa_groups = attrs[MultiHeadAttnAttr.NUM_GQA_GROUPS] \
if MultiHeadAttnAttr.NUM_GQA_GROUPS in attrs else None
layernorm_type = attrs[MultiHeadAttnAttr.LN_TYPE]
zero_centered_gamma = attrs[MultiHeadAttnAttr.ZERO_CEN]
kernel_init = WeightInit.Gaussian(1.0)
use_bias = attrs[MultiHeadAttnAttr.USE_BIAS]
bias_init = WeightInit.Constant(0.0)
apply_residual_connection_post_layernorm = False
output_layernorm = False
input_layernorm = False
return_layernorm_output = False
attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE]
fuse_qkv: bool = True
fuse_qkv_params = True
transpose_batch_sequence = True
scale_attn_logits = False
scaled_query_init = True
float32_logits = False
praxis_p = pax_fiddle.Config(
MultiHeadAttention,
praxis_p = pax_fiddle.Config(MultiHeadAttention,
name='mha',
dtype=dtype,
head_dim=head_dim,
num_heads=num_heads,
num_attention_heads=num_attention_heads,
num_gqa_groups=num_gqa_groups,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
params_init=kernel_init,
use_bias=use_bias,
bias_init=bias_init,
apply_residual_connection_post_layernorm=apply_residual_connection_post_layernorm,
output_layernorm=output_layernorm,
return_layernorm_output=return_layernorm_output,
input_layernorm=input_layernorm,
attn_mask_type=attn_mask_type,
fuse_qkv=fuse_qkv,
fuse_qkv_params=fuse_qkv_params,
transpose_batch_sequence=transpose_batch_sequence,
scale_attn_logits=scale_attn_logits,
scaled_query_init=scaled_query_init,
......@@ -768,16 +868,17 @@ class TestMultiHeadAttn(TestLayer):
flax_MultiHeadAttention,
dtype=dtype,
head_dim=head_dim,
num_heads=num_heads,
num_attention_heads=num_attention_heads,
num_gqa_groups=num_gqa_groups,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init),
use_bias=use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
apply_residual_connection_post_layernorm=apply_residual_connection_post_layernorm,
output_layernorm=output_layernorm,
return_layernorm_output=return_layernorm_output,
input_layernorm=input_layernorm,
attn_mask_type=attn_mask_type,
fuse_qkv=fuse_qkv,
fuse_qkv_params=fuse_qkv_params,
transpose_batch_sequence=transpose_batch_sequence,
scale_attn_logits=scale_attn_logits,
scaled_query_init=scaled_query_init,
......@@ -1024,6 +1125,7 @@ class TestTransformer(TestLayer):
enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE]
enable_relative_embedding = True
relative_embedding = pax_fiddle.Config(RelativePositionBiases,
dtype=dtype,
num_attention_heads=num_attention_heads)
drop_path = 0.0
transpose_batch_sequence = attrs[TransformerLayerAttr.TRANSPOSE_BS]
......
......@@ -934,7 +934,7 @@ class EncoderLayer(nn.Module):
y = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="output_layer_norm")(y)
name="output_layernorm")(y)
return y
......@@ -1090,7 +1090,7 @@ class DecoderLayer(nn.Module):
z = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="output_layer_norm")(z)
name="output_layernorm")(z)
return z
......
......@@ -105,8 +105,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
bool flag_m512 = false;
bool flag_arb = false;
if ((sm_arch_ == 80 || sm_arch_ == 90)
&& (max_seqlen_q <= 512)
&& (max_seqlen_kv <= 512)
&& (max_seqlen_q <= 512 && max_seqlen_q % 64 == 0)
&& (max_seqlen_kv <= 512 && max_seqlen_kv % 64 == 0)
&& (head_dim == 64)
&& (num_attn_heads == num_gqa_groups)
&& ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
......
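Paraphrased in Python for readability (a sketch of the tightened max512 check above; the bias/mask clauses are truncated in the hunk and omitted here):

def may_use_max512_backend(sm_arch, max_seqlen_q, max_seqlen_kv,
                           head_dim, num_attn_heads, num_gqa_groups):
    # Mirrors the C++ condition: seqlens now must also be multiples of 64.
    return (sm_arch in (80, 90)
            and max_seqlen_q <= 512 and max_seqlen_q % 64 == 0
            and max_seqlen_kv <= 512 and max_seqlen_kv % 64 == 0
            and head_dim == 64
            and num_attn_heads == num_gqa_groups)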
......@@ -1885,6 +1885,7 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
backend = FusedAttnHelper(qkv_dtype, qkv_dtype, NVTE_QKV_Layout.NVTE_BS3HD, attn_bias_type,
attn_mask_type, dropout_probability, num_heads, num_heads,
max_seqlen, max_seqlen, head_dim).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_shape = (*batch_shape, num_heads, max_seqlen, max_seqlen)
softmax_dtype = qkv_dtype
......@@ -2029,7 +2030,7 @@ def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, seqlen: jnp.ndarray
scaling_factor: float, dropout_probability: float, is_training: bool):
"""
Wrapper for TE self fused attention fwd
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
Return BMM1 -> (PreScaleBias) -> Scale -> (PostScaleBias) -> Softmax -> (Dropout) -> BMM2
"""
checker = _FusedAttnRNGStateChecker()
seed = checker.check_seed(seed, dropout_probability, is_training)
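As an unfused reference for that op order (a sketch only, assuming [batch, seqlen, heads, head_dim] inputs and a broadcastable bias; mask and dropout omitted):

import jax
import jax.numpy as jnp

def reference_dpa(q, k, v, scale, bias=None, pre_scale_bias=False):
    logits = jnp.einsum('bqhd,bkhd->bhqk', q, k)          # BMM1
    if bias is not None and pre_scale_bias:
        logits = logits + bias                            # (PreScaleBias)
    logits = logits * scale                               # Scale
    if bias is not None and not pre_scale_bias:
        logits = logits + bias                            # (PostScaleBias)
    weights = jax.nn.softmax(logits, axis=-1)             # Softmax
    return jnp.einsum('bhqk,bkhd->bqhd', weights, v)      # BMM2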
......@@ -2273,6 +2274,7 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
attn_bias_type, attn_mask_type, dropout_probability, num_heads,
num_gqa_groups, q_max_seqlen, kv_max_seqlen,
q_head_dim).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_shape = (*q_batch_shape, num_heads, q_max_seqlen, kv_max_seqlen)
softmax_dtype = q_dtype
......@@ -2426,7 +2428,7 @@ def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, q_s
dropout_probability: float, is_training: bool):
"""
Wrapper for TE cross fused attention fwd
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
Return BMM1 -> (PreScaleBias) -> Scale -> (PostScaleBias) -> Softmax -> (Dropout) -> BMM2
"""
checker = _FusedAttnRNGStateChecker()
seed = checker.check_seed(seed, dropout_probability, is_training)
......@@ -2662,6 +2664,445 @@ def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
is_training=is_training)
class FusedAttnFwdPrimitive(BasePrimitive):
"""
Fused Attention Forward Primitive
Query, key, value are separate tensors
"""
name = "te_fused_attn_forward"
multiple_results = True
impl_static_args = (7, 8, 9, 10, 11)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(q_aval, k_aval, v_aval, bias_aval, q_seqlen_or_cu_seqlen_aval,
kv_seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
"""
Fused attention fwd abstract
"""
q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
assert q_dtype == k_dtype == v_dtype == bias_dtype
assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
*q_batch_shape, q_max_seqlen, num_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape
assert q_batch_shape == kv_batch_shape
assert q_head_dim == kv_head_dim
assert k_aval.shape == v_aval.shape
out_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
# backend determines the softmax buffer shape/dtype
backend = FusedAttnHelper(q_dtype, k_dtype, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD, attn_bias_type,
attn_mask_type, dropout_probability, num_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, q_head_dim).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_shape = (*q_batch_shape, num_heads, q_max_seqlen, kv_max_seqlen)
softmax_dtype = q_dtype
elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
softmax_shape = (*q_batch_shape, num_heads, q_max_seqlen, 1)
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
else:
raise ValueError(f'Unsupported {backend=}')
softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype)
# JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with
# 32-bit unsigned int to get the buffer size we need in the C++ kernel
checker = _FusedAttnRNGStateChecker()
seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
assert seed_dtype == checker.rng_state_dtype
rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)
# do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
# prepare for the active fused-attn backend
batch_size = reduce(operator.mul, q_batch_shape)
wkspace_info = transformer_engine_jax.get_fused_attn_fwd_workspace_sizes(
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, q_head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training)
wkspace_aval = q_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
Fused attention fwd outer primitive abstract
"""
out_aval, softmax_aux_aval, rng_state_aval, _ = \
FusedAttnFwdPrimitive.abstract(*args, **kwargs)
return out_aval, softmax_aux_aval, rng_state_aval
@staticmethod
def lowering(ctx, q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training):
"""
Fused attention fwd lowering rules
"""
operands = [q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
q_aval, k_aval, v_aval, *_ = ctx.avals_in
*batch_shape, q_max_seqlen, num_heads, head_dim = q_aval.shape
*_, kv_max_seqlen, num_gqa_groups, _ = k_aval.shape
assert k_aval.shape == v_aval.shape
batch_size = reduce(operator.mul, batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, head_dim,
wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
return out
@staticmethod
def impl(q, k, v, bias, q_seqlen, kv_seqlen, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
assert FusedAttnFwdPrimitive.inner_primitive is not None
q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind(
q,
k,
v,
bias,
q_cu_seqlen,
kv_cu_seqlen,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output, softmax_aux, rng_state
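Here `generate_cu_seqlen` turns per-batch sequence lengths into the cumulative offsets the kernels index with; conceptually (a sketch, not the actual helper):

import jax.numpy as jnp

def cu_seqlen_sketch(seqlen):
    # per-batch seqlen [3, 5, 2] -> cu_seqlen [0, 3, 8, 10]
    zero = jnp.zeros((1,), dtype=jnp.int32)
    return jnp.concatenate([zero, jnp.cumsum(seqlen).astype(jnp.int32)])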
@staticmethod
def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
_check_valid_batch_dims(batch_dims)
assert FusedAttnFwdPrimitive.outer_primitive is not None
q_bdim, *_, seed_bdim = batch_dims
out_bdims = q_bdim, q_bdim, seed_bdim
return FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training), out_bdims
@staticmethod
def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training, mesh, arg_infos,
result_infos):
del attn_bias_type, attn_mask_type, scaling_factor
del dropout_probability, is_training, result_infos
q_spec = get_padded_spec(arg_infos[0]) # (...batch, q_seqlen, head, hidden)
k_spec = get_padded_spec(arg_infos[1]) # (...batch, kv_seqlen, head, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3]))
rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
return (out_sharding, softmax_aux_sharding, rng_state_sharding)
@staticmethod
def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
mesh, arg_infos, result_infos):
del result_infos
q_spec = get_padded_spec(arg_infos[0]) # (...batch, q_seqlen, head, hidden)
k_spec = get_padded_spec(arg_infos[1]) # (...batch, kv_seqlen, head, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3]))
rng_state_sharding = seed_sharding = NamedSharding(mesh,
PartitionSpec(get_all_mesh_axes(), None))
arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding])
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
impl = partial(FusedAttnFwdPrimitive.impl,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return mesh, impl, out_shardings, arg_shardings
register_primitive(FusedAttnFwdPrimitive)
def fused_attn_fwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray,
q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray, seed: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
scaling_factor: float, dropout_probability: float, is_training: bool):
"""
Wrapper for TE fused attention fwd, where query, key, value are separate tensors
Return BMM1 -> (PreScaleBias) -> Scale -> (PostScaleBias) -> Softmax -> (Dropout) -> BMM2
"""
checker = _FusedAttnRNGStateChecker()
seed = checker.check_seed(seed, dropout_probability, is_training)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=q.dtype)
return FusedAttnFwdPrimitive.outer_primitive.bind(q,
k,
v,
bias,
q_seqlen,
kv_seqlen,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
class FusedAttnBwdPrimitive(BasePrimitive):
"""
Fused Attention Backward Primitive
"""
name = "te_fused_attn_backward"
multiple_results = True
impl_static_args = (10, 11, 12, 13, 14)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(q_aval, k_aval, v_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval,
doutput_aval, q_cu_seqlen_aval, kv_cu_seqlen_aval, *, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training):
"""
Fused attention bwd abstract
"""
del softmax_aux_aval, rng_state_aval, output_aval
q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype)
assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype
assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype
*q_batch_shape, q_max_seqlen, num_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape
assert q_batch_shape == kv_batch_shape
assert q_head_dim == kv_head_dim
assert k_aval.shape == v_aval.shape
batch_size = reduce(operator.mul, q_batch_shape)
wkspace_shape, wkspace_dtype = \
transformer_engine_jax.get_fused_attn_bwd_workspace_sizes(
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, q_head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training
)
dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
dk_aval = k_aval.update(shape=k_aval.shape, dtype=k_dtype)
dv_aval = v_aval.update(shape=v_aval.shape, dtype=v_dtype)
dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
wkspace_aval = q_aval.update(shape=wkspace_shape,
dtype=te_dtype_to_jax_dtype(wkspace_dtype))
return dq_aval, dk_aval, dv_aval, dbias_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
Fused attention bwd outer primitive abstract
"""
dq_aval, dk_aval, dv_aval, dbias_aval, _ = \
FusedAttnBwdPrimitive.abstract(*args, **kwargs)
return dq_aval, dk_aval, dv_aval, dbias_aval
@staticmethod
def lowering(ctx, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen,
kv_cu_seqlen, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
"""
Fused attention bwd lowering rules
"""
operands = [
q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen
]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
q_aval, k_aval, v_aval, *_ = ctx.avals_in
*batch_shape, q_max_seqlen, num_heads, head_dim = q_aval.shape
*_, kv_max_seqlen, num_gqa_groups, _ = k_aval.shape
assert k_aval.shape == v_aval.shape
batch_size = reduce(operator.mul, batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, head_dim,
wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
return out
@staticmethod
def impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen,
attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training):
assert FusedAttnBwdPrimitive.inner_primitive is not None
q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
dq, dk, dv, dbias, _ = FusedAttnBwdPrimitive.inner_primitive.bind(
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return dq, dk, dv, dbias
@staticmethod
def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
_check_valid_batch_dims(batch_dims)
assert FusedAttnBwdPrimitive.outer_primitive is not None
q_bdim, k_bdim, v_bdim, *_ = batch_dims
out_bdims = q_bdim, k_bdim, v_bdim, q_bdim
return FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training), out_bdims
@staticmethod
def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training, mesh, arg_infos,
result_infos):
del attn_bias_type, attn_mask_type, scaling_factor
del dropout_probability, is_training, result_infos
q_spec = get_padded_spec(arg_infos[0])
k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2])
bias_spec = get_padded_spec(arg_infos[3])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
@staticmethod
def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
mesh, arg_infos, result_infos):
del result_infos
q_spec = get_padded_spec(arg_infos[0])
k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2])
bias_spec = get_padded_spec(arg_infos[3])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
def sharded_impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen,
kv_cu_seqlen):
local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl(
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
global_dbias = local_dbias
if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
return local_dq, local_dk, local_dv, global_dbias
return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(FusedAttnBwdPrimitive)
def fused_attn_bwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray,
softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray,
doutput: jnp.ndarray, q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
scaling_factor: float, dropout_probability: float, is_training: bool):
"""
Wrapper for TE fused attention bwd
Return the gradients of fused attention with separate query, key, value tensors
"""
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=q.dtype)
return FusedAttnBwdPrimitive.outer_primitive.bind(q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_seqlen,
kv_seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
class GeluPrimitive(BasePrimitive):
"""
Gelu Forward Primitive
......
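Putting the new primitive to use, a sketch of calling the separate-q/k/v fused attention through the high-level wrapper (keyword names follow the Flax module further below; that bias/mask/seed accept None here, and the AttnBiasType.NO_BIAS member name, are assumptions):

import jax.numpy as jnp
from transformer_engine.jax.fused_attn import fused_attn, AttnBiasType, AttnMaskType

b, s, h, d = 2, 128, 16, 64
q = jnp.ones((b, s, h, d), dtype=jnp.bfloat16)
k = jnp.ones((b, s, h, d), dtype=jnp.bfloat16)
v = jnp.ones((b, s, h, d), dtype=jnp.bfloat16)

out = fused_attn(q, k, v, None, None, None,    # bias, mask, seed
                 attn_bias_type=AttnBiasType.NO_BIAS,
                 attn_mask_type=AttnMaskType.NO_MASK,
                 scaling_factor=1.0 / d**0.5,
                 dropout_probability=0.0,
                 is_training=False)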
......@@ -53,6 +53,8 @@ pybind11::dict Registrations() {
dict["te_self_fused_attn_backward"] = EncapsulateFunction(SelfFusedAttnBackward);
dict["te_cross_fused_attn_forward"] = EncapsulateFunction(CrossFusedAttnForward);
dict["te_cross_fused_attn_backward"] = EncapsulateFunction(CrossFusedAttnBackward);
dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward);
dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward);
return dict;
}
......@@ -74,6 +76,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_self_fused_attn_bwd_workspace_sizes", &GetSelfFusedAttnBackwardWorkspaceSizes);
m.def("get_cross_fused_attn_fwd_workspace_sizes", &GetCrossFusedAttnForwardWorkspaceSizes);
m.def("get_cross_fused_attn_bwd_workspace_sizes", &GetCrossFusedAttnBackwardWorkspaceSizes);
m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes);
m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
pybind11::enum_<DType>(m, "DType", pybind11::module_local())
.value("kByte", DType::kByte)
......@@ -98,7 +102,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local())
.value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD)
.value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD);
.value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
.value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD);
pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local())
.value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend)
......
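The added layout is then reachable from the Python bindings:

from transformer_engine_jax import NVTE_QKV_Layout

layout = NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD    # separate q/k/v tensors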
......@@ -1253,7 +1253,6 @@ pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
// TODO(rewang): add bias for cross attn?
auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
// FP16/BF16 doesn't use this tensor
......@@ -1488,5 +1487,265 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
nvte_tensor_pack_destroy(&aux_input_tensors);
}
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto k_shape = std::vector<size_t>{batch_size * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
// F16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
TensorWrapper query_workspace_tensor;
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input buffers from XLA
void *q = buffers[0];
void *k = buffers[1];
void *v = buffers[2];
void *bias = buffers[3];
void *q_cu_seqlens = buffers[4];
void *kv_cu_seqlens = buffers[5];
void *seed = buffers[6];
// output buffers from XLA
void *output = buffers[7];
void *softmax_aux = buffers[8];
void *rng_state = buffers[9];
void *workspace = buffers[10];
// tensor sizes
auto batch_size = descriptor.batch_size;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto num_heads = descriptor.num_heads;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto head_dim = descriptor.head_dim;
auto scaling_factor = descriptor.scaling_factor;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto k_shape = std::vector<size_t>{batch_size * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
// input tensors
auto dtype = descriptor.dtype;
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype);
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
// output tensors
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16
auto o_tensor = TensorWrapper(output, q_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
// prep RNG state
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, num_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
// auxiliary tensors (to be propagated to the backward pass later)
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend,
softmax_aux);
// cuDNN workspace
auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size},
descriptor.wkspace_dtype);
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
descriptor.is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors);
}
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto k_shape = std::vector<size_t>{batch_size * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto output_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
// F16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype);
auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
TensorWrapper query_workspace_tensor;
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input buffers from XLA
void *q = buffers[0];
void *k = buffers[1];
void *v = buffers[2];
void *bias = buffers[3];
void *softmax_aux = buffers[4];
void *rng_state = buffers[5];
void *output = buffers[6];
void *doutput = buffers[7];
void *q_cu_seqlens = buffers[8];
void *kv_cu_seqlens = buffers[9];
// output buffers from XLA
void *dq = buffers[10];
void *dk = buffers[11];
void *dv = buffers[12];
void *dbias = buffers[13];
void *workspace = buffers[14];
// tensor sizes
auto batch_size = descriptor.batch_size;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto num_heads = descriptor.num_heads;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto head_dim = descriptor.head_dim;
auto scaling_factor = descriptor.scaling_factor;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto k_shape = std::vector<size_t>{batch_size * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto output_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
// input tensors
auto dtype = descriptor.dtype;
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype);
auto output_tensor = TensorWrapper(output, output_shape, dtype);
auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
// output tensors
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dk_tensor = TensorWrapper(dk, k_shape, dtype);
auto dv_tensor = TensorWrapper(dv, v_shape, dtype);
auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
// auxiliary tensors (propagated from the forward pass)
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, num_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
rng_state, bias);
// cuDNN workspace
auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
auto wkspace_dtype = descriptor.wkspace_dtype;
auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_input_tensors);
}
} // namespace jax
} // namespace transformer_engine
......@@ -236,6 +236,20 @@ pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
} // namespace jax
} // namespace transformer_engine
......
......@@ -5,11 +5,19 @@
from .module import DenseGeneral, LayerNorm
from .module import LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase
from .transformer import extend_logical_axis_rules
from .transformer import MultiHeadAttention, RelativePositionBiases
from .transformer import DotProductAttention, MultiHeadAttention, RelativePositionBiases
from .transformer import TransformerLayer, TransformerLayerType
__all__ = [
'DenseGeneral', 'LayerNorm', 'LayerNormDenseGeneral', 'LayerNormMLP',
'TransformerEngineBase', 'extend_logical_axis_rules', 'MultiHeadAttention',
'RelativePositionBiases', 'TransformerLayer', 'TransformerLayerType',
'DenseGeneral',
'LayerNorm',
'LayerNormDenseGeneral',
'LayerNormMLP',
'TransformerEngineBase',
'extend_logical_axis_rules',
'DotProductAttention',
'MultiHeadAttention',
'RelativePositionBiases',
'TransformerLayer',
'TransformerLayerType',
]
......@@ -16,6 +16,7 @@ import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import combine_masks
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap
......@@ -24,8 +25,8 @@ from jax.ad_checkpoint import checkpoint_name
from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
from ..fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from ..fused_attn import is_fused_attn_kernel_available
from ..fused_attn import self_fused_attn, cross_fused_attn
from ..fused_attn import is_fused_attn_kernel_available, canonicalize_attn_mask_type
from ..fused_attn import self_fused_attn, cross_fused_attn, fused_attn
from ..softmax import SoftmaxType
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
......@@ -71,12 +72,12 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
Parameters
----------
rules : Sequence[Tuple[str, Union[str, None]]]
rules: Sequence[Tuple[str, Union[str, None]]]
the base Flax logical axis rules to extend.
Returns
-------
extended_rules : Sequence[Tuple[str, Union[str, None]]]
extended_rules: Sequence[Tuple[str, Union[str, None]]]
the extended Flax logical axis rules.
"""
rules_map = {}
......@@ -108,60 +109,43 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
return tuple(extended_rules)
def _merge_mask(func, *masks: Optional[Array]):
masks = [m for m in masks if m is not None]
if not masks:
return None
assert all(map(lambda x: x.ndim == masks[0].ndim,
masks)), (f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}')
mask, *other_masks = masks
for other_mask in other_masks:
mask = func(mask, other_mask)
return mask
def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32):
"""Combine attention masks."""
func = jnp.logical_and
return _merge_mask(func, *masks).astype(dtype)
def combine_biases(*masks: Optional[Array]):
"""Combine attention biases."""
def func(a, b):
return a + b
return _merge_mask(func, *masks)
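The local mask-merging helpers are superseded by flax.linen.attention.combine_masks (imported above), which logically ANDs the given masks, e.g.:

import jax.numpy as jnp
from flax.linen.attention import combine_masks

m1 = jnp.array([[True, False, True]])
m2 = jnp.array([[True, True, False]])
combine_masks(m1, m2, dtype=jnp.float32)    # -> [[1., 0., 0.]]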
class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
attention_dropout: float = 0.
attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
attn_bias_type: Optional[AttnBiasType] = None
dtype: DType = jnp.float32
float32_logits: bool = False
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = True
def core_attention(query: Array,
@nn.compact
def __call__(self,
query: Array,
key: Array,
value: Array,
scale_factor: float,
transpose_batch_sequence: bool,
softmax_type: SoftmaxType = SoftmaxType.SCALED,
mask: Optional[Array] = None,
bias: Optional[Array] = None,
*,
dropout_rng: Optional[PRNGKey] = None,
dropout_rate: float = 0.,
deterministic: bool = False,
dtype: DType = jnp.float32,
float32_logits: bool = False):
"""Core attention"""
deterministic: bool = False) -> Array:
assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
batch_dim = 1 if transpose_batch_sequence else 0
batch_dim = 1 if self.transpose_batch_sequence else 0
assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], (
'q, k, v batch dims must match.')
sequence_dim = 0 if transpose_batch_sequence else 1
sequence_dim = 0 if self.transpose_batch_sequence else 1
assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.'
assert key.shape[-2] == value.shape[-2], 'k, v num_heads must match.'
assert key.shape[-2] == value.shape[-2], 'k, v num_attention_heads must match.'
assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.'
if float32_logits:
if self.scale_factor is None:
scale_factor = 1.0 / sqrt(query.shape[-1])
else:
scale_factor = self.scale_factor
del self.scale_factor
if self.float32_logits:
query = query.astype(jnp.float32)
key = key.astype(jnp.float32)
h_q, h_kv = query.shape[-2], key.shape[-2]
# The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
# Therefore, we have to maintain two code paths.
......@@ -172,7 +156,7 @@ def core_attention(query: Array,
group_size = h_q // h_kv
grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))
if transpose_batch_sequence:
if self.transpose_batch_sequence:
if is_gqa:
attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key)
else:
......@@ -193,30 +177,47 @@ def core_attention(query: Array,
attn_weights = with_sharding_constraint_by_logical_axes(
attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES))
# When a bias is present, the computation is performed as Softmax(attn_weights * scale + bias).
# When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias)
# In this case, the scale cannot be fused into the Softmax module.
if bias is not None:
if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
attn_weights = attn_weights * scale_factor
fused_scale_factor = 1.
else:
# If no bias, the scale can be fused into Softmax module
# If not post_scale_bias, the scale can be fused into Softmax module
fused_scale_factor = scale_factor
if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
attn_weights += bias
def convert_to_softmax_type(attn_mask_type, mask):
"""Convert the attn_mask_type to SoftmaxType"""
if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
return SoftmaxType.SCALED_UPPER_TRIANG_MASKED
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]:
if mask is not None:
return SoftmaxType.SCALED_MASKED
return SoftmaxType.SCALED
raise ValueError(f"Unsupported {attn_mask_type=}, "
"supported attn_mask_type = {'causal', 'padding'}")
softmax_type = convert_to_softmax_type(self.attn_mask_type, mask)
attn_weights = Softmax(softmax_type=softmax_type,
scale_factor=fused_scale_factor)(attn_weights, mask, bias).astype(dtype)
scale_factor=fused_scale_factor)(attn_weights, mask,
bias).astype(self.dtype)
if is_gqa:
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
if not deterministic and dropout_rate > 0.:
keep_prob = 1.0 - dropout_rate
if not deterministic and self.attention_dropout > 0.:
keep_prob = 1.0 - self.attention_dropout
dropout_shape = list(attn_weights.shape)
# TODO(rewang): add attention dropout broadcast dimension arguments for users
keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
multiplier = (keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype))
multiplier = (keep.astype(attn_weights.dtype) /
jnp.asarray(keep_prob, dtype=self.dtype))
attn_weights = attn_weights * multiplier
if transpose_batch_sequence:
if self.transpose_batch_sequence:
if is_gqa:
return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape)
return jnp.einsum('bhqk,kbhd->qbhd', attn_weights, value)
......@@ -226,6 +227,320 @@ def core_attention(query: Array,
return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)
class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
attention_dropout: float = 0.
attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
attn_bias_type: Optional[AttnBiasType] = None
dtype: DType = jnp.float32
qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = False
@nn.compact
def __call__(self,
query: Array,
key: Array,
value: Array,
mask: Optional[Array] = None,
bias: Optional[Array] = None,
*,
dropout_rng: Optional[PRNGKey] = None,
deterministic: bool = False) -> Array:
seed = None
if dropout_rng is not None:
seed = jax.random.split(dropout_rng, num_of_devices())
if self.scale_factor is None:
scale_factor = 1.0 / sqrt(query.shape[-1])
else:
scale_factor = self.scale_factor
del self.scale_factor
if self.qkv_layout == QKVLayout.BS3HD:
"""qkvpacked format, treat
query: qkvpacked tensor, shape = [..., 3, h, d]
key: ignore
value: ignore
"""
qkv_packed = query
if self.transpose_batch_sequence:
qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4])
x = self_fused_attn(qkv_packed,
bias,
mask,
seed,
attn_mask_type=self.attn_mask_type,
attn_bias_type=self.attn_bias_type,
scaling_factor=scale_factor,
dropout_probability=self.attention_dropout,
is_training=not deterministic)
elif self.qkv_layout == QKVLayout.BSHD_BS2HD:
"""kvpacked format, treat
query: query tensor, shape = [..., h, d]
key: kvpacked tensor, shape = [..., 2, h, d]
value: ignore
"""
kv_packed = key
if self.transpose_batch_sequence:
query = query.transpose([1, 0, 2, 3])
kv_packed = kv_packed.transpose([1, 0, 2, 3, 4])
x = cross_fused_attn(query,
kv_packed,
bias,
mask,
seed,
attn_mask_type=self.attn_mask_type,
attn_bias_type=self.attn_bias_type,
scaling_factor=scale_factor,
dropout_probability=self.attention_dropout,
is_training=not deterministic)
elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
if self.transpose_batch_sequence:
query = query.transpose([1, 0, 2, 3])
key = key.transpose([1, 0, 2, 3])
value = value.transpose([1, 0, 2, 3])
x = fused_attn(query,
key,
value,
bias,
mask,
seed,
attn_mask_type=self.attn_mask_type,
attn_bias_type=self.attn_bias_type,
scaling_factor=scale_factor,
dropout_probability=self.attention_dropout,
is_training=not deterministic)
else:
raise ValueError(f"Unsupported {self.qkv_layout=}.")
if self.transpose_batch_sequence:
x = x.transpose([1, 0, 2, 3])
return x
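# Illustrative sketch (not part of this change): how callers are expected to pack
# the inputs for each qkv_layout branch above. Shapes only; the tensors are dummies.
def _qkv_packing_example():
    import jax.numpy as jnp
    b, s, h, d = 2, 128, 16, 64
    q = jnp.zeros((b, s, h, d))
    k = jnp.zeros((b, s, h, d))
    v = jnp.zeros((b, s, h, d))
    # BS3HD: pass the packed tensor as `query`; `key` and `value` are ignored.
    qkv_packed = jnp.stack([q, k, v], axis=-3)
    assert qkv_packed.shape == (b, s, 3, h, d)
    # BSHD_BS2HD: pass q as `query` and the packed tensor as `key`.
    kv_packed = jnp.stack([k, v], axis=-3)
    assert kv_packed.shape == (b, s, 2, h, d)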
class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
r"""
Dot Product Attention (DPA). Allows the model to jointly attend to information from different
representation subspaces as described in the paper:
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
.. note::
The DotProductAttention module supports two backends: the unfused and the fused attention
mechanisms. The unfused attention is implemented using JAX native operations, providing
broad compatibility and flexibility. In contrast, the fused attention uses `cuDNN fused
attention
<https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md>`_ for
higher performance and lower memory usage on supported hardware.
Users can select between these two backends via the :attr:`NVTE_FUSED_ATTN` environment
variable:
* Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention (default).
* Set :attr:`NVTE_FUSED_ATTN=1` for fused attention. If the required cuDNN fused attention
kernel is not available on the system, a warning will be issued, and the module will
automatically fall back to the unfused backend.
Parameters
----------
head_dim: int
The hidden dimension of each attention head.
num_attention_heads: int
The number of attention heads.
num_gqa_groups: int, default = `None`
Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the queries.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
attention_dropout: float, default = 0.0
Dropout probability for the dropout op after the softmax.
attn_mask_type: str, default = 'causal'
Type of the attention mask passed into softmax operation in the self attention.
Available options: {'no_mask', 'padding', 'causal', 'causal_padding'}
Introduced in v0.10.0.
attn_bias_type: Optional[str], default = None
Type of the attention bias passed in the self attention.
Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
When `None` (the default), the type is decided automatically by the :attr:`bias` argument
of :attr:`__call__()`: :attr:`post_scale_bias` if a bias is given, otherwise :attr:`no_bias`.
dropout_rng_name: str, default = 'dropout'
The key in given RNGs via flax.linen.Module.apply that is used
to generate Dropout masks in the core attention.
float32_logits: bool, default = False
Whether to compute attention logits in float32 for the unfused attention backend.
For the fused attention backend, accumulation is always performed in float32 with no
performance overhead.
qkv_layout: str, default = 'bshd_bshd_bshd'
Specifies the dimensional layout format for the query, key, and value tensors in __call__().
It indicates how the inputs are processed.
Available options: {'bs3hd', 'bshd_bs2hd', 'bshd_bshd_bshd'}. Where
* bs3hd: query tensor is treated as a qkvpacked tensor with shape = [b, s, 3, h, d].
key and value arguments in :attr:`__call__()` are ignored in this layout.
* bshd_bs2hd: query tensor with shape = [b, s, h, d]. key tensor is treated as a kvpacked
tensor with shape = [b, s, 2, h, d]. `value` argument in :attr:`__call__()` is ignored.
* bshd_bshd_bshd: query, key, and value are separate tensors, each with shape = [b, s, h, d].
Explanation of denotations:
* b: batch size
* s: sequence length
* h: num_attention_heads or num_gqa_groups
* d: head dimension
scale_factor: Optional[float], default = None
Scale factor to apply on query. When :attr:`None` is present, the scale factor is equal
to :math:`\frac{1}{\sqrt{head\_dim}}`. This is useful for models like T5X, which do not
scale the query; for those, set :attr:`scale_factor=1.`.
transpose_batch_sequence: bool, default = True
Indicate whether the batch and sequence-length axes of the input tensors are swapped.
If set to True, the input tensors should be in (seqlen, batch, ...),
otherwise (batch, seqlen, ...).
Optimization parameters
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
"""
head_dim: int
num_attention_heads: int
num_gqa_groups: Optional[int] = None
attention_dropout: float = 0.
attn_mask_type: str = 'causal'
attn_bias_type: Optional[str] = None
dtype: DType = jnp.float32
dropout_rng_name: str = 'dropout'
float32_logits: bool = False
qkv_layout: str = 'bshd_bshd_bshd'
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = True
@nn.compact
def __call__(self,
query: Array,
key: Array,
value: Array,
mask: Optional[Array] = None,
bias: Optional[Array] = None,
*,
deterministic: bool = False) -> Array:
"""
Parameters
----------
query: jax.numpy.ndarray
The details of the query tensor representation are described in :attr:`qkv_layout`.
key: jax.numpy.ndarray
The details of the key tensor representation are described in :attr:`qkv_layout`.
value: jax.numpy.ndarray
The details of the value tensor representation are described in :attr:`qkv_layout`.
mask: jax.numpy.ndarray, default = None
Boolean tensor used to mask out the attention softmax input.
:attr:`True` means to mask out the corresponding values.
bias: jax.numpy.ndarray, default = None
A tensor used to shift attention softmax input.
*:
The following parameters are keyword-only.
deterministic: bool, default = False
Disable dropout layers if set to True.
Returns
-------
outputs: jax.numpy.ndarray
Output tensors.
"""
# The internal APIs take enums, so convert the string-based options here
if self.attn_bias_type is None:
attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
else:
attn_bias_type = AttnBiasType[self.attn_bias_type.upper()]
attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
qkv_layout = QKVLayout[self.qkv_layout.upper()]
del self.attn_bias_type, self.attn_mask_type, self.qkv_layout
if attn_bias_type == AttnBiasType.NO_BIAS:
assert bias is None
else:
assert bias is not None
enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))
sequence_dim = 0 if self.transpose_batch_sequence else 1
seqlen_q = query.shape[sequence_dim]
if qkv_layout == QKVLayout.BS3HD:
seqlen_kv = seqlen_q
else:
seqlen_kv = key.shape[sequence_dim]
has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype, qkv_layout,
attn_bias_type, attn_mask_type,
self.attention_dropout,
self.num_attention_heads,
self.num_gqa_groups, seqlen_q,
seqlen_kv, self.head_dim)
use_fused_attn = (enable_fused_attn and has_fused_attn_kernel)
if enable_fused_attn and not has_fused_attn_kernel:
warnings.warn("Fused attention is not enabled because there is no available kernel.\n"
"Fall back to the unfused attention.\n"
"Please try to update the cuDNN and TE to the latest version.\n"
f"{self.dtype=}\n{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n"
f"{self.attention_dropout=}\n{self.num_attention_heads=}\n"
f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{self.head_dim=}\n")
dropout_rng = None
if not deterministic and self.attention_dropout > 0.:
dropout_rng = self.make_rng(self.dropout_rng_name)
if self.scale_factor is None:
scale_factor = 1.0 / sqrt(self.head_dim)
else:
scale_factor = self.scale_factor
del self.scale_factor
if not use_fused_attn:
# unfused attention only supports separate query, key, value
if qkv_layout == QKVLayout.BS3HD:
query, key, value = jnp.split(query, [1, 2], axis=-3)
query, key, value = map(functools.partial(jnp.squeeze, axis=-3),
[query, key, value])
elif qkv_layout == QKVLayout.BSHD_BS2HD:
key, value = jnp.split(key, [1], axis=-3)
key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value])
else:
assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
x = _UnfusedDotProductAttention(attention_dropout=self.attention_dropout,
attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type,
dtype=self.dtype,
float32_logits=self.float32_logits,
scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence)(
query,
key,
value,
mask,
bias,
dropout_rng=dropout_rng,
deterministic=deterministic)
else:
x = _FusedDotProductAttention(
attention_dropout=self.attention_dropout,
attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type,
dtype=self.dtype,
scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence,
qkv_layout=qkv_layout,
)(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)
return x
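# Usage sketch (illustrative, not part of this change), assuming the default
# 'bshd_bshd_bshd' layout with batch-first inputs; sizes are dummies. Set
# NVTE_FUSED_ATTN=1 in the environment to opt into the cuDNN fused backend.
def _dpa_usage_example():
    import jax
    import jax.numpy as jnp
    b, s, h, d = 2, 128, 16, 64
    q = k = v = jnp.zeros((b, s, h, d), jnp.bfloat16)
    dpa = DotProductAttention(head_dim=d,
                              num_attention_heads=h,
                              attn_mask_type='no_mask',
                              attention_dropout=0.1,
                              transpose_batch_sequence=False,
                              dtype=jnp.bfloat16)
    variables = dpa.init(jax.random.PRNGKey(0), q, k, v, deterministic=True)
    # Training step: provide the RNG collection named by dropout_rng_name.
    out = dpa.apply(variables, q, k, v,
                    rngs={'dropout': jax.random.PRNGKey(1)},
                    deterministic=False)
    assert out.shape == (b, s, h, d)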
def rotary_pos_emb(x: Array, windows: Tuple[int, int], transpose_batch_sequence: bool):
"""
Rotary Positional Embedding
......@@ -259,43 +574,44 @@ def rotary_pos_emb(x: Array, windows: Tuple[int, int], transpose_batch_sequence:
return jnp.concatenate([part_1, part_2], axis=-1)
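# For reference (illustrative, assuming the usual min/max timescale windows): the
# half-split rotary form concatenated above is the standard rotation of paired
# features by a position-dependent angle, i.e. with x split into halves (x1, x2),
# position p, and per-feature frequency theta_i derived from the windows:
#     part_1 = x1 * cos(p * theta_i) - x2 * sin(p * theta_i)
#     part_2 = x2 * cos(p * theta_i) + x1 * sin(p * theta_i)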
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
r"""
Multi-head Attention (MHA), including Query,
Key, Value and Output projection.
.. note::
Argument :attr:`mask` will be ignored when
:attr:`attn_mask_type` is set to `"causal"`.
Parameters
----------
head_dim : int
head_dim: int
The hidden dimension of each attention head.
num_heads : int
The number of attention heads
num_gqa_groups : int, default = `None`
Number of GQA groups. When `None` is present, it is equal to num_heads.
num_attention_heads: int
The number of attention heads.
num_gqa_groups: int, default = `None`
Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the queries.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
dropout_rate : float, default = 0.0
Dropout probability for the dropout op during multi-head attention.
attention_dropout: float, default = 0.0
Dropout probability for the dropout op after the softmax.
attn_mask_type: str, default = 'causal'
Type of the attention mask passed into softmax operation in the attention.
Available options: {'no_mask', 'padding', 'causal', 'causal_padding'}
Introduced in v0.10.0.
attn_bias_type: Optional[str], default = None
Type of the attention bias passed in the attention.
Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
When `None` (the default), the type is decided automatically by the MHA's bias parameter:
`post_scale_bias` if a bias is given, otherwise `no_bias`.
dropout_rng_name: str, default = 'dropout'
The key in given RNGs via flax.linen.Module.apply that is used
to generate Dropout masks in the core attention.
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
Indicate the type of layer normalization.
layernorm_epsilon: float, default = 1e-6
A value added to the denominator of layer normalization for numerical stability.
zero_centered_gamma : bool, default = False
zero_centered_gamma: bool, default = False
If set to `True`, the LayerNorm formula changes to
.. math::
......@@ -305,21 +621,20 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
This parameter is only applicable for 'layernorm'.
kernel_init: Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
Used for initializing the QKV and Output projection weights.
Used for initializing the QKV and output projection weights.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
use_bias: bool, default = False
Indicate whether or not to enable bias shifting for QKVO projections.
Indicate whether or not to enable bias shifting for QKV and output projections.
If set to False, the layer will not learn additive biases.
bias_init: Initializer, default = flax.linen.initializers.zeros
Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
apply_residual_connection_post_layernorm : bool, default = False
Indicate if apply residual connection with the output of layer normalization.
output_layernorm : bool, default = False
Indicate if apply a layer normalization at the end of MHA.
attn_mask_type: {'causal', 'padding'}, default = 'causal'
Type of attention mask passed into softmax operation.
Introduced in v0.10.0.
input_layernorm: bool, default = True
If set to False, layer normalization to the input is not applied.
return_layernorm_output: bool, default = False
If set to True, output of layernorm is returned from the forward together with the output
of the linear transformation.
Example use case: residual connection for transformer module is taken post layernorm.
enable_rotary_pos_emb: bool, default = False
Whether to enable rotary position embedding to projected query and key.
rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
......@@ -327,58 +642,101 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
only used when :attr:`enable_rotary_pos_emb=True`
enable_sequence_parallel: bool, default = False
Whether to enable sequence parallelism to operations except dot.
num_heads: int, default = None
Deprecated. Please refer to `num_attention_heads`.
dropout_rate: float, default = None
Deprecated. Please refer to `attention_dropout`.
output_layernorm: bool, default = None
Deprecated. Please refer to `input_layernorm`.
apply_residual_connection_post_layernorm: bool, default = None
Deprecated. Please refer to `return_layernorm_output`.
Optimization parameters
-----------------------
dtype :jax.numpy.dtype, default = jax.numpy.float32
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
fuse_qkv: bool, default = True
fuse_qkv_params: bool, default = True
If set to True, this module exposes a single fused
parameter for query-key-value for self-attention and key-value for
cross-attention.
transpose_batch_sequence : bool, default = True
transpose_batch_sequence: bool, default = True
Indicate whether the batch and sequence-length axes of the input tensors are swapped.
If set to True, the input tensors should be in (seqlen, batch, hidden),
otherwise (batch, seqlen, hidden).
scale_attn_logits: bool, default = False
Indicate whether to scale attention logits.
If set to True, :math:`\frac{Q}{\sqrt{head_dim}*K}`,
If set to True, :math:`\frac{Q}{\sqrt{head\_dim}} * K`,
else :math:`Q*K`
scaled_query_init: bool, default = `True`
Whether to scale WQ on initialization by :math:`\sqrt{head_dim}`
float32_logits : bool, default = False
Whether to compute attention logits in float32.
scaled_query_init: bool, default = True
Whether to scale WQ on initialization by :math:`\frac{1}{\sqrt{head\_dim}}`
float32_logits: bool, default = False
Whether to compute attention logits in float32 for the unfused attention backend.
For the fused attention backend, accumulation is always performed in float32 with no
performance overhead.
fuse_qkv: bool, default = None
Deprecated. Please refer `fuse_qkv_params`
"""
head_dim: int
num_heads: int
num_gqa_groups: int | None = None
dropout_rate: float = 0.
num_attention_heads: int
num_gqa_groups: Optional[int] = None
attention_dropout: float = 0.
dropout_rng_name: str = 'dropout'
input_layernorm: bool = True
layernorm_type: str = "layernorm"
layernorm_epsilon: float = 1e-6
return_layernorm_output: bool = False
zero_centered_gamma: bool = False
kernel_init: Initializer = None
use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros
apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False
attn_mask_type: str = 'causal'
attn_bias_type: Optional[str] = None
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
dtype: DType = jnp.float32
fuse_qkv: bool = True
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = True
enable_sequence_parallel: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
float32_logits: bool = False # computes logits in float32 for stability.
float32_logits: bool = False
# Deprecated parameters
num_heads: Optional[int] = None
dropout_rate: Optional[float] = None
output_layernorm: Optional[bool] = None
apply_residual_connection_post_layernorm: Optional[bool] = None
fuse_qkv: Optional[bool] = None
def __post_init__(self):
# Deal with the deprecated parameters
if self.num_heads is not None:
self.num_attention_heads = self.num_heads
warnings.warn(
f"{__class__}.num_heads is deprecated. It will be removed recently. "
f"Please uses {__class__}.num_attention_heads as the new API.", DeprecationWarning)
if self.dropout_rate is not None:
self.attention_dropout = self.dropout_rate
warnings.warn(
f"{__class__}.dropout_rate is deprecated. It will be removed recently. "
f"Please use {__class__}.attention_dropout as the new API.", DeprecationWarning)
if self.apply_residual_connection_post_layernorm is not None:
warnings.warn(
f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
f"It will be removed recently, please use {__class__}.return_layernorm_output.",
DeprecationWarning)
if self.fuse_qkv is not None:
warnings.warn(
f"{__class__}.fuse_qkv is deprecated. It will be removed recently. "
f"Please use {__class__}.fuse_qkv_params as the new API.", DeprecationWarning)
assert self.output_layernorm is None, (
f"{__class__}.output_layernorm is deprecated. It will be removed recently. "
f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm.")
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_heads
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
@nn.compact
......@@ -396,23 +754,24 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
Parameters
----------
inputs_q : jax.numpy.ndarray
inputs_q: jax.numpy.ndarray
Input tensor for query projection.
inputs_kv : jax.numpy.ndarray
inputs_kv: jax.numpy.ndarray
Input tensor for key/value projection.
mask : jax.numpy.ndarray, default = None
Boolean tensor used to mask out self-attention softmax input.
bias : jax.numpy.ndarray, default = None
A tensor used to shift self-attention softmax input.
mask: jax.numpy.ndarray, default = None
Boolean tensor used to mask out the attention softmax input.
:attr:`True` means mask out the corresponding values.
bias: jax.numpy.ndarray, default = None
A tensor used to shift the attention softmax input.
*
decode : bool,default = False
decode: bool, default = False
Indicate whether to prepare and use an autoregressive cache.
deterministic : bool,default = False
deterministic: bool, default = False
Disable dropout layers if set to True.
Returns
-------
outputs : jax.numpy.ndarray
outputs: jax.numpy.ndarray
Output tensors.
"""
......@@ -450,56 +809,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
return jnp.stack([k_kernel, v_kernel], axis=-2, dtype=dtype)
# TODO(rewang): make it configurable for pre_scale_bias
attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
def canonicalize_attn_mask_type(attn_mask_type):
"""
Convert the string to AttnMaskType
"""
if attn_mask_type == 'causal':
return AttnMaskType.PADDING_CAUSAL_MASK
if attn_mask_type == 'padding':
return AttnMaskType.PADDING_MASK
raise ValueError(f"Unsupported {attn_mask_type=}, "
"supported attn_mask_type = {'causal', 'padding'}")
is_self_attn = (inputs_q is inputs_kv)
is_gqa = (self.num_heads != self.num_gqa_groups)
is_qkvpack = (is_self_attn and not is_gqa)
qkv_layout = QKVLayout.BS3HD if is_self_attn else QKVLayout.BSHD_BS2HD
attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
q_seqlen = inputs_q.shape[0] if self.transpose_batch_sequence else inputs_q.shape[1]
kv_seqlen = inputs_kv.shape[0] if self.transpose_batch_sequence else inputs_kv.shape[1]
enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))
has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype, qkv_layout,
attn_bias_type, attn_mask_type,
self.dropout_rate, self.num_heads,
self.num_gqa_groups, q_seqlen,
kv_seqlen, self.head_dim)
use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \
has_fused_attn_kernel and \
enable_fused_attn
if enable_fused_attn and not use_fused_attn:
reason = ""
if decode:
reason += f"decode=False is required but got {decode}, "
if self.transpose_batch_sequence:
reason += f"transpose_batch_sequence=False is required " \
f"but got {self.transpose_batch_sequence}, "
if not self.fuse_qkv:
reason += f"fuse_qkv=True is required but got {self.fuse_qkv}, "
if not has_fused_attn_kernel:
reason += "no fused attention kernel is available, "
warnings.warn(
f"Fused attention is not enabled. Because " \
f"{reason}fall back to unfused attention.")
def generate_batch_seqlen_logical_axes(is_sharded_seq):
sequence_dim = 0 if self.transpose_batch_sequence else 1
batch_dim = 1 - sequence_dim
......@@ -510,24 +819,27 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
axes[sequence_dim] = SEQLEN_TP_AXES if is_sharded_seq else SEQLEN_AXES
return tuple(axes)
is_self_attn = (inputs_q is inputs_kv)
is_gqa = (self.num_attention_heads != self.num_gqa_groups)
is_qkvpack = (is_self_attn and not is_gqa)
inputs_logical_axes_maybe_sp = (*generate_batch_seqlen_logical_axes(
self.enable_sequence_parallel), HIDDEN_AXES)
inputs_logical_axes_no_sp = (*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES)
inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp)
residual = inputs_q
if self.fuse_qkv:
if self.fuse_qkv_params:
if is_qkvpack:
qkv_proj, ln_out = LayerNormDenseGeneral(
enable_layernorm=not self.output_layernorm,
enable_layernorm=self.input_layernorm,
layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon,
axis=-1,
features=(3, self.num_heads * self.head_dim),
features=(3, self.num_attention_heads * self.head_dim),
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm,
return_layernorm_output=self.return_layernorm_output,
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
......@@ -540,19 +852,17 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
name='qkv',
dtype=self.dtype)(inputs_q)
qkv_proj = checkpoint_name(qkv_proj, 'combined_qkv_proj')
if not use_fused_attn:
query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
qkv_layout = QKVLayout.BS3HD
else:
query, ln_out = LayerNormDenseGeneral(
enable_layernorm=not self.output_layernorm,
enable_layernorm=self.input_layernorm,
layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon,
axis=-1,
features=self.num_heads * self.head_dim,
features=self.num_attention_heads * self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=(self.apply_residual_connection_post_layernorm
or is_self_attn),
return_layernorm_output=(self.return_layernorm_output or is_self_attn),
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes=(W_FSDP_AXES, W_TP_AXES),
......@@ -580,8 +890,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
name='kv',
dtype=self.dtype)(inputs_kv)
kv_proj = checkpoint_name(kv_proj, 'combined_kv_proj')
if not use_fused_attn:
key, value = jnp.split(kv_proj, [1], axis=-2)
qkv_layout = QKVLayout.BSHD_BS2HD
else:
kv_projection = functools.partial(
DenseGeneral,
......@@ -594,12 +903,12 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
bias_axes=(W_TP_AXES,),
dtype=self.dtype)
query, ln_out = LayerNormDenseGeneral(
enable_layernorm=not self.output_layernorm,
enable_layernorm=self.input_layernorm,
layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon,
axis=-1,
features=self.num_heads * self.head_dim,
features=self.num_attention_heads * self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=True,
scale_axes=(W_NO_SHARD_AXES,),
......@@ -620,44 +929,31 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv)
if self.apply_residual_connection_post_layernorm:
assert ln_out is not None
residual = ln_out
query = checkpoint_name(query, 'query_proj')
key = checkpoint_name(key, 'key_proj')
value = checkpoint_name(value, 'value_proj')
qkv_layout = QKVLayout.BSHD_BSHD_BSHD
if self.enable_rotary_pos_emb:
if self.fuse_qkv and use_fused_attn:
if is_qkvpack:
if qkv_layout == QKVLayout.BS3HD:
query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
else:
elif qkv_layout == QKVLayout.BSHD_BS2HD:
key, value = jnp.split(kv_proj, [1], axis=-2)
else:
assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
query = rotary_pos_emb(query, self.rotary_pos_emb_windows,
self.transpose_batch_sequence)
key = rotary_pos_emb(key, self.rotary_pos_emb_windows, self.transpose_batch_sequence)
qkv_layout = QKVLayout.BSHD_BSHD_BSHD
if use_fused_attn:
if is_qkvpack:
qkv_proj = jnp.concatenate([query, key, value], axis=-2)
else:
kv_proj = jnp.concatenate([key, value], axis=-2)
if not use_fused_attn:
query = checkpoint_name(query, 'query_proj')
key = checkpoint_name(key, 'key_proj')
value = checkpoint_name(value, 'value_proj')
query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
if qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
qkv_sharding_constraint = \
(SEQLEN_AXES, BATCH_AXES, HEAD_AXES, HIDDEN_AXES) \
if self.transpose_batch_sequence \
else (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES)
query = with_sharding_constraint_by_logical_axes(query, qkv_sharding_constraint)
key = with_sharding_constraint_by_logical_axes(key, qkv_sharding_constraint)
value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint)
if decode:
assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
is_initialized = self.has_variable('cache', 'cached_key')
cached_key = self.variable('cache', 'cached_key', jnp.zeros, key.shape, key.dtype)
......@@ -667,12 +963,12 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized:
if self.transpose_batch_sequence:
length, batch, num_heads, head_dim = cached_key.value.shape
expected_shape = (1, batch, num_heads, head_dim)
length, batch, num_attention_heads, head_dim = cached_key.value.shape
expected_shape = (1, batch, num_attention_heads, head_dim)
one_hot_indices_shape = (length, 1, 1, 1)
else:
batch, length, num_heads, head_dim = cached_key.value.shape
expected_shape = (batch, 1, num_heads, head_dim)
batch, length, num_attention_heads, head_dim = cached_key.value.shape
expected_shape = (batch, 1, num_attention_heads, head_dim)
one_hot_indices_shape = (1, length, 1, 1)
# Sanity shape check of cached key against input query.
......@@ -694,100 +990,58 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length)))
if bias is not None:
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim,
in_axes=(None, 0, None, None))
bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0),
jnp.reshape(cur_index, (-1)), 1, -2)
scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0
dropout_rng = None
if not deterministic and self.dropout_rate > 0.:
dropout_rng = self.make_rng(self.dropout_rng_name)
if use_fused_attn:
assert mask is not None and mask.ndim == 4 # (b, 1, s_q, s_kv)
assert not self.transpose_batch_sequence
seed = None
if dropout_rng is not None:
seed = jax.random.split(dropout_rng, num_of_devices())
# ensure the old key is never used
del dropout_rng
if is_qkvpack:
qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim))
qkv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
HIDDEN_AXES)
qkv_proj = with_sharding_constraint_by_logical_axes(qkv_proj,
qkv_sharding_constraint)
x = self_fused_attn(qkv_proj,
bias,
mask,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scale_factor,
dropout_probability=self.dropout_rate,
is_training=not deterministic)
else:
assert bias is None
query = query.reshape((*query.shape[:-1], self.num_heads, self.head_dim))
kv_proj = kv_proj.reshape((*kv_proj.shape[:-1], self.num_gqa_groups, self.head_dim))
q_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES)
kv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
HIDDEN_AXES)
LEADING_AXES = (BATCH_AXES, SEQLEN_AXES)
if self.transpose_batch_sequence:
LEADING_AXES = (SEQLEN_AXES, BATCH_AXES)
if qkv_layout == QKVLayout.BS3HD:
qkv_proj = qkv_proj.reshape(*qkv_proj.shape[:2], 3, self.num_attention_heads,
self.head_dim)
qkv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
qkv_proj = with_sharding_constraint_by_logical_axes(qkv_proj, qkv_sharding_constraint)
dpa_args = [qkv_proj, None, None]
elif qkv_layout == QKVLayout.BSHD_BS2HD:
query = query.reshape(*query.shape[:2], self.num_attention_heads, self.head_dim)
kv_proj = kv_proj.reshape(*kv_proj.shape[:2], 2, self.num_gqa_groups, self.head_dim)
q_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
kv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
query = with_sharding_constraint_by_logical_axes(query, q_sharding_constraint)
kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint)
x = cross_fused_attn(query,
kv_proj,
bias,
mask,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scale_factor,
dropout_probability=self.dropout_rate,
is_training=not deterministic)
dpa_args = [query, kv_proj, None]
else:
assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
qkv_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
query = with_sharding_constraint_by_logical_axes(query, qkv_sharding_constraint)
key = with_sharding_constraint_by_logical_axes(key, qkv_sharding_constraint)
value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint)
dpa_args = [query, key, value]
def convert_to_softmax_type(attn_mask_type, mask):
"""
Convert the string to SoftmaxType
"""
if attn_mask_type == 'causal':
return SoftmaxType.SCALED_UPPER_TRIANG_MASKED
if attn_mask_type == 'padding':
if mask is not None:
return SoftmaxType.SCALED_MASKED
return SoftmaxType.SCALED
raise ValueError(f"Unsupported {attn_mask_type=}, "
"supported attn_mask_type = {'causal', 'padding'}")
softmax_type = convert_to_softmax_type(self.attn_mask_type, mask)
x = core_attention(query,
key,
value,
scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence,
softmax_type=softmax_type,
mask=mask,
bias=bias,
dropout_rng=dropout_rng,
dropout_rate=self.dropout_rate,
deterministic=deterministic,
x = DotProductAttention(head_dim=self.head_dim,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
attn_mask_type=self.attn_mask_type,
attn_bias_type=self.attn_bias_type,
attention_dropout=self.attention_dropout,
dtype=self.dtype,
float32_logits=self.float32_logits)
x = checkpoint_name(x, 'context')
dropout_rng_name=self.dropout_rng_name,
float32_logits=self.float32_logits,
qkv_layout=qkv_layout.name,
scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence)(
*dpa_args, mask, bias, deterministic=deterministic)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
attn_context_sharding_constraint = \
(SEQLEN_AXES, BATCH_AXES, HIDDEN_TP_AXES) \
if self.transpose_batch_sequence \
else (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
attn_context_sharding_constraint = (*LEADING_AXES, HIDDEN_TP_AXES)
x = with_sharding_constraint_by_logical_axes(x, attn_context_sharding_constraint)
out = DenseGeneral(features=inputs_q.shape[-1],
......@@ -801,7 +1055,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
dtype=self.dtype,
name='out')(x)
out = checkpoint_name(out, 'out_proj')
return out, residual
return out, ln_out
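# Usage sketch (illustrative, not part of this change): self-attention with the
# renamed arguments; sizes are dummies. The second return value is the layernorm
# output when return_layernorm_output=True (None otherwise).
def _mha_usage_example():
    import jax
    import jax.numpy as jnp
    s, b, hidden = 128, 2, 1024
    x = jnp.zeros((s, b, hidden))  # transpose_batch_sequence=True by default
    mha = MultiHeadAttention(head_dim=64,
                             num_attention_heads=16,
                             attn_mask_type='causal',
                             attention_dropout=0.1)
    variables = mha.init(jax.random.PRNGKey(0), x, x, deterministic=True)
    out, ln_out = mha.apply(variables, x, x,
                            rngs={'dropout': jax.random.PRNGKey(1)},
                            deterministic=False)
    assert out.shape == x.shape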
class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-methods
......@@ -810,21 +1065,21 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-met
Parameters
----------
num_buckets : int
num_buckets: int
The number of buckets to bucket distances between key and query positions into.
max_distance : int
max_distance: int
The maximum distance before everything is lumped into the last
distance bucket.
num_attention_heads : int
num_attention_heads: int
Number of attention heads in the transformer layer.
embedding_init : Initializer, default = flax.linen.linear.default_embed_init
embedding_init: Initializer, default = flax.linen.linear.default_embed_init
Used for initializing relative embedding tables.
embedding_axes : Tuple[str, ...], default = ('heads', 'relpos_buckets')
embedding_axes: Tuple[str, ...], default = ('heads', 'relpos_buckets')
The name of axes used to shard embedding attention bias with a corresponding mesh.
Optimization parameters
-----------------------
dtype : jax.numpy.dtype, default = jax.numpy.float32
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
"""
num_buckets: int
......@@ -841,11 +1096,11 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-met
Parameters
----------
q_seqlen : int
q_seqlen: int
The sequence length of query.
k_seqlen : int
k_seqlen: int
The sequence length of key.
bidirectional : bool, default = True
bidirectional: bool, default = True
Indicate whether to allow positive memory-query relative position
embeddings.
......@@ -917,11 +1172,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
an attention block and a feedforward network (MLP).
This standard layer is based on the paper “Attention Is All You Need”.
.. note::
Argument :attr:`attention_mask` will be ignored when
:attr:`self_attn_mask_type` is set to `"causal"`.
Parameters
----------
hidden_size: int, default = 512
......@@ -930,7 +1180,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
Intermediate size to which input samples are projected.
num_attention_heads: int, default = 8
Number of attention heads in the transformer layer.
num_gqa_groups : int, default = `None`
num_gqa_groups: int, default = `None`
Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
......@@ -938,11 +1188,11 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
Indicate the type of layer normalization.
layernorm_epsilon: float, default = 1e-6
A value added to the denominator of layer normalization for numerical stability.
zero_centered_gamma : bool, default = False
zero_centered_gamma: bool, default = False
If set to `True`, the LayerNorm formula changes to
.. math::
......@@ -989,14 +1239,21 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
after the final dropout-add. The default behavior is to apply layer
normalization on the input side, before the QKV transformation.
float32_attention_logits: bool, default = False
If set to True, attention logits are executed in jax.numpy.float32.
Whether to compute attention logits in float32 for the unfused attention backend.
For fused attention backend, the accumulation is always float32 without the perf overhead.
layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER
If set to TransformerLayerType.DECODER, an additional cross-attention block
is added after self-attention. This can be used for structures like the `T5`
Transformer in conjunction with the TransformerLayerType.ENCODER option.
self_attn_mask_type: {'causal', 'padding'}, default = 'causal'
Type of attention mask passed into softmax operation.
self_attn_mask_type: str, default = 'causal'
Type of the attention mask passed into softmax operation in the self attention.
Available options: {'no_mask', 'padding', 'causal', 'causal_padding'}
Introduced in v0.10.0.
self_attn_bias_type: Optional[str], default = None
Type of the attention bias passed into the self attention.
Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
When `None` (the default), the type is decided automatically by the MHA's bias parameter:
`post_scale_bias` if a bias is given, otherwise `no_bias`.
enable_relative_embedding: bool, default = True
Whether to enable relative embedding as shifting of attention logits.
relative_embedding: flax.linen.Module, default = None
......@@ -1017,7 +1274,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
Optimization parameters
-----------------------
dtype :jax.numpy.dtype, default = jax.numpy.float32
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
drop_path: float, default = 0.0
When > 0.0, applies stochastic depth per sample in the main
......@@ -1026,7 +1283,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
If set to True, `TransformerLayer` module exposes a single fused
parameter for query-key-value for self-attention and key-value for
cross-attention.
transpose_batch_sequence : bool, default = False
transpose_batch_sequence: bool, default = False
Indicate whether the batch and sequence-length axes of the input tensors are swapped.
If set to True, the input tensors should be in (seqlen, batch, hidden),
otherwise (batch, seqlen, hidden).
......@@ -1041,7 +1298,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
hidden_size: int = 512
mlp_hidden_size: int = 2048
num_attention_heads: int = 8
num_gqa_groups: int | None = None
num_gqa_groups: Optional[int] = None
layernorm_type: str = 'layernorm'
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
......@@ -1061,6 +1318,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
float32_attention_logits: bool = False
layer_type: TransformerLayerType = TransformerLayerType.ENCODER
self_attn_mask_type: str = 'causal'
self_attn_bias_type: Optional[str] = None
enable_relative_embedding: bool = True
relative_embedding: nn.Module = None
enable_rotary_pos_emb: bool = False
......@@ -1097,29 +1355,29 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
Parameters
----------
inputs : jax.numpy.ndarray
inputs: jax.numpy.ndarray
Input tensor.
encoded : jax.numpy.ndarray, default = None
encoded: jax.numpy.ndarray, default = None
Output tensors of the encoder block to be fed into the decoder block if using
:attr:`layer_type=TransformerLayerType.DECODER`.
attention_mask : jax.numpy.ndarray, default = None
Boolean tensor used to mask out self-attention softmax input.
encoder_decoder_mask : jax.numpy.ndarray, default = None
encoder_decoder_mask: jax.numpy.ndarray, default = None
Boolean tensor used to mask out cross-attention softmax input when
:attr:`layer_type=TransformerLayerType.DECODER`.
deterministic: bool, default = False
Disable dropout layers if set to True.
decode: bool,default = False
decode: bool, default = False
Indicate whether to prepare and use an autoregressive cache
in Multi-head attention (MHA).
max_decode_length : bool, default = None
max_decode_length: int, default = None
The maximum length to generate relative embedding biases when
:attr:`layer_type=TransformerLayerType.DECODER` and
:attr:`enable_relative_embedding=True`.
Returns
-------
outputs : jax.numpy.ndarray
outputs: jax.numpy.ndarray
Output tensors.
"""
assert self.layer_type in TransformerLayerType, \
......@@ -1184,14 +1442,15 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
inputs, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
# [batch, length, emb_dim] -> [batch, length, emb_dim]
x, residual = MultiHeadAttention(
num_heads=self.num_attention_heads,
residual = inputs
x, ln_out = MultiHeadAttention(
num_attention_heads=self.num_attention_heads,
dtype=self.dtype,
head_dim=head_dim,
num_gqa_groups=self.num_gqa_groups,
transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel,
dropout_rate=self.attention_dropout,
attention_dropout=self.attention_dropout,
dropout_rng_name=self.dropout_rng_name,
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
......@@ -1199,12 +1458,13 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
layernorm_type=self.layernorm_type,
layernorm_epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
output_layernorm=self.output_layernorm,
return_layernorm_output=self.apply_residual_connection_post_layernorm,
input_layernorm=not self.output_layernorm,
attn_mask_type=self.self_attn_mask_type,
attn_bias_type=self.self_attn_bias_type,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
fuse_qkv=self.fuse_qkv_params,
fuse_qkv_params=self.fuse_qkv_params,
kernel_init=self.mha_kernel_init,
use_bias=self.use_bias,
bias_init=self.bias_init,
......@@ -1236,6 +1496,11 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
x = nn.Dropout(rate=self.drop_path,
broadcast_dims=drop_path_shape,
rng_collection=self.dropout_rng_name)(x, deterministic=deterministic)
if self.apply_residual_connection_post_layernorm:
assert ln_out is not None
residual = ln_out
x = x + residual
mlp_input = x
......@@ -1246,28 +1511,29 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
x = with_sharding_constraint_by_logical_axes(
x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
y, residual = MultiHeadAttention(
num_heads=self.num_attention_heads,
residual = x
y, ln_out = MultiHeadAttention(
num_attention_heads=self.num_attention_heads,
dtype=self.dtype,
head_dim=head_dim,
num_gqa_groups=self.num_gqa_groups,
transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel,
dropout_rate=self.attention_dropout,
attention_dropout=self.attention_dropout,
dropout_rng_name=self.dropout_rng_name,
layernorm_type=self.layernorm_type,
layernorm_epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
apply_residual_connection_post_layernorm=self.
apply_residual_connection_post_layernorm,
output_layernorm=False, # Must do LayerNorm before MHA.
return_layernorm_output=self.apply_residual_connection_post_layernorm,
input_layernorm=True, # Must do LayerNorm before MHA.
attn_mask_type='padding',
attn_bias_type='no_bias',
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
fuse_qkv=self.fuse_qkv_params,
fuse_qkv_params=self.fuse_qkv_params,
kernel_init=self.mha_kernel_init,
use_bias=self.use_bias,
bias_init=self.bias_init,
......@@ -1282,6 +1548,11 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
y = hidden_dropout(y, deterministic)
if self.apply_residual_connection_post_layernorm:
assert ln_out is not None
residual = ln_out
mlp_input = y + residual
mlp_input = with_sharding_constraint_by_logical_axes(
......@@ -1342,6 +1613,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
bias_axes=(W_NO_SHARD_AXES,),
transpose_batch_sequence=self.transpose_batch_sequence,
dtype=self.dtype,
name="output_layer_norm")(z)
name="output_layernorm")(z)
return z
......@@ -16,6 +16,7 @@ from transformer_engine_jax import NVTE_QKV_Layout
from .cpp_extensions import FusedAttnHelper
from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd
from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd
from .cpp_extensions import fused_attn_fwd, fused_attn_bwd
class AttnBiasType(Enum):
......@@ -37,6 +38,21 @@ class QKVLayout(Enum):
"""QKV layout"""
BS3HD = NVTE_QKV_Layout.NVTE_BS3HD
BSHD_BS2HD = NVTE_QKV_Layout.NVTE_BSHD_BS2HD
BSHD_BSHD_BSHD = NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD
def canonicalize_attn_mask_type(attn_mask_type: str):
"""Convert string attn_mask_type to AttnMaskType
TE-JAX currently falls back to the padding-version kernels for library integration.
The overhead between the padding and non-padding versions should be small.
However, we will lift this limitation in the near future.
"""
if attn_mask_type in ['causal', 'padding_causal']:
return AttnMaskType.PADDING_CAUSAL_MASK
if attn_mask_type in ['no_mask', 'padding']:
return AttnMaskType.PADDING_MASK
raise ValueError(f"Unsupported {attn_mask_type=}, "
"supported attn_mask_type={'no_mask', 'padding', 'causal', 'padding_causal'}")
def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, attn_mask_type,
......@@ -83,6 +99,10 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.nda
seed: jnp.ndarray, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
if mask is None:
batch, seqlen, *_ = qkv.shape
actual_seqlen = jnp.full((batch,), seqlen, dtype=jnp.int32)
else:
mask = jnp.logical_not(mask)
actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
output, softmax_aux, rng_state = self_fused_attn_fwd(qkv,
......@@ -159,13 +179,18 @@ def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask:
def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
if mask is None:
batch, s_q, *_ = q.shape
s_kv = kv.shape[1]
q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
else:
mask = jnp.logical_not(mask)
q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
else:
# When mask is padding + causal, the actual seqlen is not the last row, use max to find it
# When the mask is causal, the actual seqlen is not in the last row; use max to find it
kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))
output, softmax_aux, rng_state = cross_fused_attn_fwd(q,
......@@ -179,7 +204,9 @@ def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mas
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
output = checkpoint_name(output, 'context')
softmax_aux = checkpoint_name(softmax_aux, 'context')
rng_state = checkpoint_name(rng_state, 'context')
return output, (q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen)
......@@ -209,3 +236,100 @@ def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, d
_cross_fused_attn.defvjp(_cross_fused_attn_fwd_rule, _cross_fused_attn_bwd_rule)
def fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
scaling_factor: float, dropout_probability: float, is_training: bool):
"""
Dot product attention with separate query, key, and value tensors
"""
output = _fused_attn(q,
k,
v,
bias,
mask,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output
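# Usage sketch (illustrative, not part of this change): separate q/k/v in BSHD
# layout on a GPU with a supported cuDNN fused-attention kernel. The mask is
# boolean with True marking positions to mask out; seed may be None when dropout
# is disabled. Sizes are dummies.
def _fused_attn_example():
    import jax.numpy as jnp
    b, s, h, d = 2, 128, 16, 64
    q = jnp.zeros((b, s, h, d), jnp.bfloat16)
    k = jnp.zeros((b, s, h, d), jnp.bfloat16)
    v = jnp.zeros((b, s, h, d), jnp.bfloat16)
    mask = jnp.zeros((b, 1, s, s), jnp.bool_)  # nothing masked out
    out = fused_attn(q, k, v, None, mask, None,
                     attn_bias_type=AttnBiasType.NO_BIAS,
                     attn_mask_type=AttnMaskType.PADDING_MASK,
                     scaling_factor=1.0 / d**0.5,
                     dropout_probability=0.,
                     is_training=True)
    assert out.shape == (b, s, h, d)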
@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10))
def _fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray,
mask: jnp.ndarray, seed: jnp.ndarray, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float, dropout_probability: float,
is_training: bool):
output, _ = _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training)
return output
def _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
if mask is None:
batch, s_q, *_ = q.shape
s_kv = k.shape[1]
q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
else:
mask = jnp.logical_not(mask)
q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
else:
# When the mask is causal, the actual seqlen is not in the last row; use max to find it
kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))
output, softmax_aux, rng_state = fused_attn_fwd(q,
k,
v,
bias,
q_actual_seqlen,
kv_actual_seqlen,
seed,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
output = checkpoint_name(output, 'context')
softmax_aux = checkpoint_name(softmax_aux, 'context')
rng_state = checkpoint_name(rng_state, 'context')
return output, (q, k, v, bias, softmax_aux, rng_state, output, q_actual_seqlen,
kv_actual_seqlen)
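# Illustrative check (not part of this change) of the seqlen derivation above:
# for a boolean mask of shape [b, 1, s_q, s_kv] where True marks masked-out
# positions, summing the inverted mask recovers the per-batch valid lengths.
def _seqlen_from_mask_example():
    import jax.numpy as jnp
    s = 4
    valid = 3    # tokens 0..2 are valid, token 3 is padding
    pad = jnp.arange(s) >= valid
    mask = (pad[:, None] | pad[None, :])[None, None]  # [1, 1, s_q, s_kv]
    inv = jnp.logical_not(mask)
    q_len = jnp.sum(inv, axis=-2, dtype=jnp.int32)[..., 0, 0]
    kv_len = jnp.sum(inv, axis=-1, dtype=jnp.int32)[..., 0, 0]
    assert int(q_len[0]) == valid and int(kv_len[0]) == valid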
def _fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, dz):
q, k, v, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx
grad_q, grad_k, grad_v, grad_bias = fused_attn_bwd(q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
dz,
q_actual_seqlen,
kv_actual_seqlen,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
if attn_bias_type == AttnBiasType.NO_BIAS:
grad_bias = None
return grad_q, grad_k, grad_v, grad_bias, None, None
_fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule)
......@@ -4,5 +4,6 @@
"""Praxis related Modules"""
from .module import FusedSoftmax, LayerNorm
from .module import LayerNormLinear, LayerNormMLP, Linear, TransformerEngineBaseLayer
from .transformer import MultiHeadAttention, RelativePositionBiases, TransformerLayer
from .transformer import DotProductAttention, MultiHeadAttention
from .transformer import RelativePositionBiases, TransformerLayer
from ..flax.transformer import TransformerLayerType
......@@ -6,6 +6,7 @@ Praxis Modules related Transformer
"""
from functools import partial
from typing import Optional, Sequence, Tuple
import warnings
from praxis import pax_fiddle
from praxis.base_layer import WeightInit
......@@ -13,9 +14,11 @@ from praxis.pytypes import JTensor
from .module import TransformerEngineBaseLayer
from ..flax.transformer import TransformerLayerType
from ..flax.transformer import DotProductAttention as flax_DotProductAttention
from ..flax.transformer import MultiHeadAttention as flax_MultiHeadAttention
from ..flax.transformer import RelativePositionBiases as flax_RelativePositionBiases
from ..flax.transformer import TransformerLayer as flax_TransformerLayer
from ..fused_attn import AttnBiasType, AttnMaskType
class RelativePositionBiases(TransformerEngineBaseLayer):
......@@ -59,30 +62,117 @@ class RelativePositionBiases(TransformerEngineBaseLayer):
return self.relative_position_bias(q_seqlen, k_seqlen, bidirectional)
class DotProductAttention(TransformerEngineBaseLayer):
"""DotProductAttention"""
head_dim: int = 0
num_attention_heads: int = 0
num_gqa_groups: Optional[int] = None
attention_dropout: float = 0.
attn_mask_type: str = 'causal'
attn_bias_type: Optional[str] = None
dropout_rng_name: str = 'dropout'
float32_logits: bool = False
qkv_layout: str = 'bshd_bshd_bshd'
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = True
def setup(self) -> None:
"""setup"""
super().setup()
assert self.head_dim > 0, f'{self.head_dim=}'
assert self.num_attention_heads > 0, f'{self.num_attention_heads=}'
dpa_cls = partial(flax_DotProductAttention,
head_dim=self.head_dim,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
attn_mask_type=self.attn_mask_type,
attn_bias_type=self.attn_bias_type,
attention_dropout=self.attention_dropout,
dtype=self.dtype,
dropout_rng_name=self.dropout_rng_name,
float32_logits=self.float32_logits,
qkv_layout=self.qkv_layout,
scale_factor=self.scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence)
self.create_layer("dot_product_attention", dpa_cls)
def __call__(self,
query: JTensor,
key: JTensor,
value: JTensor,
mask: Optional[JTensor] = None,
bias: Optional[JTensor] = None,
*,
deterministic: bool = False) -> JTensor:
"""__call__"""
return self.dot_product_attention(query,
key,
value,
mask,
bias,
deterministic=deterministic)
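# Usage sketch (illustrative, not part of this change), assuming the usual praxis
# flow of building a pax_fiddle.Config and instantiating it via base_layer; the
# names below are hypothetical:
#
#     from praxis import base_layer
#     dpa_p = pax_fiddle.Config(DotProductAttention,
#                               name='dpa',
#                               head_dim=64,
#                               num_attention_heads=16,
#                               transpose_batch_sequence=False)
#     dpa = base_layer.instantiate(dpa_p)
#     # dpa.init(...) / dpa.apply(...) then follow the flax-style calls above.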
class MultiHeadAttention(TransformerEngineBaseLayer):
"""MultiHeadAttention"""
head_dim: int = 64
num_heads: int = 16
num_gqa_groups: int | None = None
dropout_rate: float = 0.
head_dim: int = 0
num_attention_heads: int = 0
num_gqa_groups: Optional[int] = None
attention_dropout: float = 0.
dropout_rng_name: str = 'dropout'
input_layernorm: bool = True
layernorm_type: str = "layernorm"
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
return_layernorm_output: bool = False
use_bias: bool = False
bias_init: WeightInit = WeightInit.Constant(0.0)
apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False
attn_mask_type: str = 'causal'
fuse_qkv: bool = True
attn_bias_type: Optional[str] = None
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = True
enable_sequence_parallel: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
float32_logits: bool = False
# Deprecated parameters
num_heads: Optional[int] = None
dropout_rate: Optional[float] = None
output_layernorm: Optional[bool] = None
apply_residual_connection_post_layernorm: Optional[bool] = None
fuse_qkv: Optional[bool] = None
def __post_init__(self):
# Deal with the deprecated parameters
if self.num_heads is not None:
self.num_attention_heads = self.num_heads
warnings.warn(
f"{__class__}.num_heads is deprecated and will be removed in a future release. "
f"Please use {__class__}.num_attention_heads instead.", DeprecationWarning)
if self.dropout_rate is not None:
self.attention_dropout = self.dropout_rate
warnings.warn(
f"{__class__}.dropout_rate is deprecated and will be removed in a future release. "
f"Please use {__class__}.attention_dropout instead.", DeprecationWarning)
if self.apply_residual_connection_post_layernorm is not None:
warnings.warn(
f"{__class__}.apply_residual_connection_post_layernorm is deprecated and will be "
f"removed in a future release. Please use {__class__}.return_layernorm_output instead.",
DeprecationWarning)
if self.fuse_qkv is not None:
warnings.warn(
f"{__class__}.fuse_qkv is deprecated and will be removed in a future release. "
f"Please use {__class__}.fuse_qkv_params instead.", DeprecationWarning)
assert self.output_layernorm is None, (
f"{__class__}.output_layernorm is deprecated and will be removed in a future release. "
f"Please use {__class__}.input_layernorm to control whether layernorm is applied.")
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
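
To make the remapping above concrete, a small sketch of how the deprecated field names behave once `__post_init__` runs; the config values are hypothetical:

```python
import warnings
from praxis import base_layer, pax_fiddle
from transformer_engine.jax.praxis import MultiHeadAttention

mha_cfg = pax_fiddle.Config(
    MultiHeadAttention,
    name='mha',
    head_dim=64,
    num_heads=16,      # deprecated alias; remapped to num_attention_heads
    dropout_rate=0.1,  # deprecated alias; remapped to attention_dropout
)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    mha = base_layer.instantiate(mha_cfg)  # __post_init__ fires the warnings

assert mha.num_attention_heads == 16
assert mha.attention_dropout == 0.1
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```

Note that `pax_fiddle.Config` is lazy, so the warnings are only emitted when the layer is actually instantiated, not when the config is built.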
......@@ -91,24 +181,28 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
"""setup"""
super().setup()
assert self.head_dim > 0, f'{self.head_dim=}'
assert self.num_attention_heads > 0, f'{self.num_attention_heads=}'
mha_cls = partial(
flax_MultiHeadAttention,
dtype=self.dtype,
head_dim=self.head_dim,
num_heads=self.num_heads,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
dropout_rate=self.dropout_rate,
attention_dropout=self.attention_dropout,
dropout_rng_name=self.dropout_rng_name,
input_layernorm=self.input_layernorm,
layernorm_type=self.layernorm_type,
layernorm_epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
return_layernorm_output=self.return_layernorm_output,
kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
use_bias=self.use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
output_layernorm=self.output_layernorm,
attn_mask_type=self.attn_mask_type,
fuse_qkv=self.fuse_qkv,
attn_bias_type=self.attn_bias_type,
fuse_qkv_params=self.fuse_qkv_params,
transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel,
scale_attn_logits=self.scale_attn_logits,
......@@ -140,7 +234,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
hidden_size: int = 512
mlp_hidden_size: int = 2048
num_attention_heads: int = 8
num_gqa_groups: int | None = None
num_gqa_groups: Optional[int] = None
layernorm_type: str = 'layernorm'
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
......@@ -158,6 +252,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
float32_attention_logits: bool = False
layer_type: TransformerLayerType = TransformerLayerType.ENCODER
self_attn_mask_type: str = 'causal'
self_attn_bias_type: Optional[str] = None
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
enable_relative_embedding: bool = True
......@@ -226,6 +321,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
float32_attention_logits=self.float32_attention_logits,
layer_type=self.layer_type,
self_attn_mask_type=self.self_attn_mask_type,
self_attn_bias_type=self.self_attn_bias_type,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
enable_relative_embedding=self.enable_relative_embedding,
......
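
Finally, a hypothetical sketch of the new `self_attn_bias_type` knob on the Praxis `TransformerLayer`. The value mirrors the `AttnBiasType` options imported above, but the exact string spelling is an assumption here:

```python
from praxis import base_layer, pax_fiddle
from transformer_engine.jax.praxis import TransformerLayer

layer_cfg = pax_fiddle.Config(
    TransformerLayer,
    name='encoder_layer',
    hidden_size=512,
    mlp_hidden_size=2048,
    num_attention_heads=8,
    self_attn_mask_type='padding',
    self_attn_bias_type='post_scale_bias',  # new field from this change (assumed spelling)
)
layer = base_layer.instantiate(layer_cfg)
# If self_attn_bias_type is left as None, the underlying flax layer is expected
# to infer the bias type from whether a bias tensor is passed at call time.
```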