Unverified Commit 9b2fed51 authored by Reese Wang's avatar Reese Wang Committed by GitHub
Browse files

[JAX] Refine MHA API and add DPA API (#653)



* Refine MHA API
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Reuse func from the flax
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* DPA draft
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* qkv packed draft
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Fix test_layer with fused attn
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Add attn_bias_type and enhance a few code flow
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Move scale_factor from __call__ to init
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Enhance the docs
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Add DPA public API and tests
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Refine docs
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Refine docs
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Fix conflict
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Add qkv separate fused attn
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Apply BSHD_BSHD_BSHD format
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Remove debug log
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Add fused attention layer tests
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Add NVTE_FUSED_ATTN docs
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Fine-grained fused attn settings
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Remove the default value of num_attetnion_head and head_dim
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Add teardown for fused attn env
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Unify the Optional notation
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Fix Pre/Post scale bias comments
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Add no_mask tests
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Add checkpoint_name for fused attn
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Fix the fused attn batcher
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

---------
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>
parent fb2f952a
...@@ -45,6 +45,9 @@ Modules ...@@ -45,6 +45,9 @@ Modules
.. autoapiclass:: transformer_engine.jax.flax.RelativePositionBiases(num_buckets, max_distance, num_heads, **kwargs) .. autoapiclass:: transformer_engine.jax.flax.RelativePositionBiases(num_buckets, max_distance, num_heads, **kwargs)
:members: __call__ :members: __call__
.. autoapiclass:: transformer_engine.jax.flax.DotProductAttention(head_dim, num_heads, **kwargs)
:members: __call__
.. autoapiclass:: transformer_engine.jax.flax.MultiHeadAttention(head_dim, num_heads, **kwargs) .. autoapiclass:: transformer_engine.jax.flax.MultiHeadAttention(head_dim, num_heads, **kwargs)
:members: __call__ :members: __call__
......
...@@ -20,7 +20,7 @@ from jax import value_and_grad, jit ...@@ -20,7 +20,7 @@ from jax import value_and_grad, jit
from jax.typing import ArrayLike, DTypeLike from jax.typing import ArrayLike, DTypeLike
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn, fused_attn
from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
...@@ -144,6 +144,9 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng ...@@ -144,6 +144,9 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng
kv = jnp.concatenate((key, value), axis=-3) kv = jnp.concatenate((key, value), axis=-3)
return cross_fused_attn(query, kv, bias, mask, dropout_rng, return cross_fused_attn(query, kv, bias, mask, dropout_rng,
**kwargs).astype(query.dtype) **kwargs).astype(query.dtype)
case QKVLayout.BSHD_BSHD_BSHD:
return fused_attn(query, key, value, bias, mask, dropout_rng,
**kwargs).astype(query.dtype)
@dataclass @dataclass
...@@ -337,6 +340,7 @@ class FusedAttnRunner: ...@@ -337,6 +340,7 @@ class FusedAttnRunner:
@pytest.mark.parametrize('qkv_layout', [ @pytest.mark.parametrize('qkv_layout', [
pytest.param(QKVLayout.BS3HD, id='qkvpacked'), pytest.param(QKVLayout.BS3HD, id='qkvpacked'),
pytest.param(QKVLayout.BSHD_BS2HD, id='kvpacked'), pytest.param(QKVLayout.BSHD_BS2HD, id='kvpacked'),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id='separate'),
]) ])
@pytest.mark.parametrize('dropout_prob', [0., 0.1]) @pytest.mark.parametrize('dropout_prob', [0., 0.1])
@pytest.mark.parametrize('is_training', @pytest.mark.parametrize('is_training',
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
import os
from functools import partial from functools import partial
import flax import flax
...@@ -20,6 +21,16 @@ from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available ...@@ -20,6 +21,16 @@ from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_fp8_available()
@pytest.fixture(autouse=True, scope='module')
def enable_fused_attn():
"""
Enable fused attention
"""
os.environ["NVTE_FUSED_ATTN"] = "1"
yield
del os.environ["NVTE_FUSED_ATTN"]
@pytest.fixture(autouse=True, scope='function') @pytest.fixture(autouse=True, scope='function')
def clear_live_arrays(): def clear_live_arrays():
""" """
...@@ -93,6 +104,7 @@ _KEY_OF_ENABLE_ROPE = "enable_rotary_pos_emb" ...@@ -93,6 +104,7 @@ _KEY_OF_ENABLE_ROPE = "enable_rotary_pos_emb"
BASE_ATTRS = { BASE_ATTRS = {
_KEY_OF_TRANSPOSE_BS: True, _KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_NUM_HEADS: 8, _KEY_OF_NUM_HEADS: 8,
_KEY_OF_DROPOUT_RATE: 0,
} }
ATTRS = [{ ATTRS = [{
...@@ -221,6 +233,7 @@ class TestEncoderLayer: ...@@ -221,6 +233,7 @@ class TestEncoderLayer:
ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng) ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng)
test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng) test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng)
if attrs[_KEY_OF_DROPOUT_RATE] == 0.: # Skip elementwise checking for dropout
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol) assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
del data_rng, init_rng, apply_rng del data_rng, init_rng, apply_rng
...@@ -282,9 +295,6 @@ class TestEncoderLayer: ...@@ -282,9 +295,6 @@ class TestEncoderLayer:
test_out, test_grads = grad_fn(inputs, test_masks, test_params, test_others, test_layer, test_out, test_grads = grad_fn(inputs, test_masks, test_params, test_others, test_layer,
apply_rng) apply_rng)
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad
def reorganize_test_wgrad(test_wgrad, attrs): def reorganize_test_wgrad(test_wgrad, attrs):
num_heads = attrs.get(_KEY_OF_NUM_HEADS) num_heads = attrs.get(_KEY_OF_NUM_HEADS)
num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads) num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads)
...@@ -328,6 +338,10 @@ class TestEncoderLayer: ...@@ -328,6 +338,10 @@ class TestEncoderLayer:
del unfreeze_test_wgrad['mlp']['wo_kernel'] del unfreeze_test_wgrad['mlp']['wo_kernel']
return unfreeze_test_wgrad return unfreeze_test_wgrad
if attrs[_KEY_OF_DROPOUT_RATE] == 0.: # Skip elementwise checking for dropout
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad
compare_dict(ref_grads[1], compare_dict(ref_grads[1],
reorganize_test_wgrad(test_grads[1], attrs), reorganize_test_wgrad(test_grads[1], attrs),
rtol=rtol, rtol=rtol,
...@@ -430,6 +444,7 @@ class TestDecoderLayer: ...@@ -430,6 +444,7 @@ class TestDecoderLayer:
ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng) ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng)
test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng) test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng)
if attrs[_KEY_OF_DROPOUT_RATE] == 0.: # Skip elementwise checking for dropout
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol) assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
del data_rng, init_rng, apply_rng del data_rng, init_rng, apply_rng
...@@ -492,9 +507,6 @@ class TestDecoderLayer: ...@@ -492,9 +507,6 @@ class TestDecoderLayer:
test_out, test_grads = grad_fn(inputs, test_masks, test_params, test_others, test_layer, test_out, test_grads = grad_fn(inputs, test_masks, test_params, test_others, test_layer,
apply_rng) apply_rng)
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad
def reorganize_test_wgrad(test_wgrad, attrs): def reorganize_test_wgrad(test_wgrad, attrs):
num_heads = attrs.get(_KEY_OF_NUM_HEADS) num_heads = attrs.get(_KEY_OF_NUM_HEADS)
num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads) num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads)
...@@ -547,6 +559,9 @@ class TestDecoderLayer: ...@@ -547,6 +559,9 @@ class TestDecoderLayer:
del unfreeze_test_wgrad['mlp']['wo_kernel'] del unfreeze_test_wgrad['mlp']['wo_kernel']
return unfreeze_test_wgrad return unfreeze_test_wgrad
if attrs[_KEY_OF_DROPOUT_RATE] == 0.: # Skip elementwise checking for dropout
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad
compare_dict(ref_grads[1], compare_dict(ref_grads[1],
reorganize_test_wgrad(test_grads[1], attrs), reorganize_test_wgrad(test_grads[1], attrs),
rtol=rtol, rtol=rtol,
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
import os
from functools import partial from functools import partial
from typing import Dict from typing import Dict
...@@ -14,12 +15,14 @@ import pytest ...@@ -14,12 +15,14 @@ import pytest
from utils import assert_allclose from utils import assert_allclose
from transformer_engine_jax import get_device_compute_capability
from transformer_engine.common.recipe import DelayedScaling, Format from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast, update_fp8_metas, update_collections from transformer_engine.jax import fp8_autocast, update_fp8_metas, update_collections
from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral
from transformer_engine.jax.flax import LayerNorm as flax_LayerNorm from transformer_engine.jax.flax import LayerNorm as flax_LayerNorm
from transformer_engine.jax.flax import LayerNormMLP as flax_LayerNormMLP from transformer_engine.jax.flax import LayerNormMLP as flax_LayerNormMLP
from transformer_engine.jax.flax import MultiHeadAttention as flax_MultiHeadAttention from transformer_engine.jax.flax import MultiHeadAttention as flax_MultiHeadAttention
from transformer_engine.jax.flax import DotProductAttention as flax_DotProductAttention
from transformer_engine.jax.flax import RelativePositionBiases as flax_RelativePositionBiases from transformer_engine.jax.flax import RelativePositionBiases as flax_RelativePositionBiases
from transformer_engine.jax.flax import TransformerLayer as flax_TransformerLayer from transformer_engine.jax.flax import TransformerLayer as flax_TransformerLayer
from transformer_engine.jax.flax.module import Softmax from transformer_engine.jax.flax.module import Softmax
...@@ -27,8 +30,8 @@ from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available ...@@ -27,8 +30,8 @@ from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
from transformer_engine.jax.praxis import LayerNorm from transformer_engine.jax.praxis import LayerNorm
from transformer_engine.jax.praxis import FusedSoftmax from transformer_engine.jax.praxis import FusedSoftmax
from transformer_engine.jax.praxis import LayerNormLinear, LayerNormMLP, Linear from transformer_engine.jax.praxis import LayerNormLinear, LayerNormMLP, Linear
from transformer_engine.jax.praxis import MultiHeadAttention, RelativePositionBiases from transformer_engine.jax.praxis import DotProductAttention, MultiHeadAttention
from transformer_engine.jax.praxis import TransformerEngineBaseLayer from transformer_engine.jax.praxis import RelativePositionBiases, TransformerEngineBaseLayer
from transformer_engine.jax.praxis import TransformerLayer, TransformerLayerType from transformer_engine.jax.praxis import TransformerLayer, TransformerLayerType
from transformer_engine.jax.softmax import SoftmaxType from transformer_engine.jax.softmax import SoftmaxType
...@@ -40,6 +43,19 @@ ENABLE_FP8 = [False, True] ...@@ -40,6 +43,19 @@ ENABLE_FP8 = [False, True]
FP8_FORMATS = [Format.E4M3, Format.HYBRID] FP8_FORMATS = [Format.E4M3, Format.HYBRID]
@pytest.fixture(autouse=True, scope='module')
def enable_fused_attn():
"""
Enable fused attn for hopper+ arch.
Fused attn kernels on pre-hopper arch are not deterministic.
"""
if get_device_compute_capability(0) >= 90:
os.environ["NVTE_FUSED_ATTN"] = "1"
yield
if "NVTE_FUSED_ATTN" in os.environ:
del os.environ["NVTE_FUSED_ATTN"]
@pytest.fixture(autouse=True, scope='function') @pytest.fixture(autouse=True, scope='function')
def clear_live_arrays(): def clear_live_arrays():
""" """
...@@ -101,6 +117,7 @@ class TestLayer: ...@@ -101,6 +117,7 @@ class TestLayer:
lyr_name = self.get_layer_name() lyr_name = self.get_layer_name()
if 'params' in flax_variables:
synced_praxis_variables['params'][lyr_name]['cld'] = \ synced_praxis_variables['params'][lyr_name]['cld'] = \
flax.core.unfreeze(flax_variables['params']) flax.core.unfreeze(flax_variables['params'])
...@@ -111,6 +128,7 @@ class TestLayer: ...@@ -111,6 +128,7 @@ class TestLayer:
lyr_name = self.get_layer_name() lyr_name = self.get_layer_name()
if 'params' in synced_praxis_grads:
synced_praxis_grads['params'] = \ synced_praxis_grads['params'] = \
synced_praxis_grads['params'][lyr_name]['cld'] synced_praxis_grads['params'][lyr_name]['cld']
...@@ -671,6 +689,86 @@ class TestRelativePositionBias(TestLayer): ...@@ -671,6 +689,86 @@ class TestRelativePositionBias(TestLayer):
assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol) assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol)
class DotProductAttnAttr:
ATTN_MASK_TYPE = 'attn_mask_type'
NUM_GQA_GROUPS = 'num_gqa_groups'
TRANSPOSE_BS = 'transpose_batch_sequence'
SCALE_FACTOR = 'scale_factor'
ATTRS = [{
ATTN_MASK_TYPE: 'padding',
TRANSPOSE_BS: True,
SCALE_FACTOR: 0.125,
}, {
ATTN_MASK_TYPE: 'padding_causal',
TRANSPOSE_BS: True,
SCALE_FACTOR: 0.125,
}, {
ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: True,
SCALE_FACTOR: 0.125,
}, {
ATTN_MASK_TYPE: 'padding',
TRANSPOSE_BS: False,
SCALE_FACTOR: 0.125,
}, {
ATTN_MASK_TYPE: 'padding_causal',
TRANSPOSE_BS: False,
SCALE_FACTOR: 2.,
}, {
ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: False,
SCALE_FACTOR: 1.,
}, {
ATTN_MASK_TYPE: 'no_mask',
TRANSPOSE_BS: False,
SCALE_FACTOR: 1.,
}]
class TestDotProductAttn(TestLayer):
def input_getter(self, shape, dtype):
key = jax.random.PRNGKey(seed=1234)
q_key, k_key, v_key = jax.random.split(key, 3)
return list(map(partial(jax.random.normal, shape=shape, dtype=dtype),
[q_key, k_key, v_key]))
def get_layer_name(self):
return 'dot_product_attn'
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
head_dim = 64
num_attention_heads = 16
num_gqa_groups = num_attention_heads
attn_mask_type = attrs[DotProductAttnAttr.ATTN_MASK_TYPE]
transpose_batch_sequence = attrs[DotProductAttnAttr.TRANSPOSE_BS]
praxis_p = pax_fiddle.Config(DotProductAttention,
name='mha',
dtype=dtype,
head_dim=head_dim,
num_attention_heads=num_attention_heads,
num_gqa_groups=num_gqa_groups,
attn_mask_type=attn_mask_type,
transpose_batch_sequence=transpose_batch_sequence)
flax_cls = partial(flax_DotProductAttention,
dtype=dtype,
head_dim=head_dim,
num_attention_heads=num_attention_heads,
num_gqa_groups=num_gqa_groups,
attn_mask_type=attn_mask_type,
transpose_batch_sequence=transpose_batch_sequence)
return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', [(32, 128, 16, 64)])
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', DotProductAttnAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class MultiHeadAttnAttr: class MultiHeadAttnAttr:
USE_BIAS = 'use_bias' USE_BIAS = 'use_bias'
LN_TYPE = 'layernorm_type' LN_TYPE = 'layernorm_type'
...@@ -730,36 +828,38 @@ class TestMultiHeadAttn(TestLayer): ...@@ -730,36 +828,38 @@ class TestMultiHeadAttn(TestLayer):
def generate_praxis_p_and_flax_cls(self, dtype, attrs): def generate_praxis_p_and_flax_cls(self, dtype, attrs):
head_dim = 64 head_dim = 64
num_heads = 16 num_attention_heads = 16
num_gqa_groups = attrs[MultiHeadAttnAttr.NUM_GQA_GROUPS] \
if MultiHeadAttnAttr.NUM_GQA_GROUPS in attrs else None
layernorm_type = attrs[MultiHeadAttnAttr.LN_TYPE] layernorm_type = attrs[MultiHeadAttnAttr.LN_TYPE]
zero_centered_gamma = attrs[MultiHeadAttnAttr.ZERO_CEN] zero_centered_gamma = attrs[MultiHeadAttnAttr.ZERO_CEN]
kernel_init = WeightInit.Gaussian(1.0) kernel_init = WeightInit.Gaussian(1.0)
use_bias = attrs[MultiHeadAttnAttr.USE_BIAS] use_bias = attrs[MultiHeadAttnAttr.USE_BIAS]
bias_init = WeightInit.Constant(0.0) bias_init = WeightInit.Constant(0.0)
apply_residual_connection_post_layernorm = False input_layernorm = False
output_layernorm = False return_layernorm_output = False
attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE] attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE]
fuse_qkv: bool = True fuse_qkv_params = True
transpose_batch_sequence = True transpose_batch_sequence = True
scale_attn_logits = False scale_attn_logits = False
scaled_query_init = True scaled_query_init = True
float32_logits = False float32_logits = False
praxis_p = pax_fiddle.Config( praxis_p = pax_fiddle.Config(MultiHeadAttention,
MultiHeadAttention,
name='mha', name='mha',
dtype=dtype, dtype=dtype,
head_dim=head_dim, head_dim=head_dim,
num_heads=num_heads, num_attention_heads=num_attention_heads,
num_gqa_groups=num_gqa_groups,
layernorm_type=layernorm_type, layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
params_init=kernel_init, params_init=kernel_init,
use_bias=use_bias, use_bias=use_bias,
bias_init=bias_init, bias_init=bias_init,
apply_residual_connection_post_layernorm=apply_residual_connection_post_layernorm, return_layernorm_output=return_layernorm_output,
output_layernorm=output_layernorm, input_layernorm=input_layernorm,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
fuse_qkv=fuse_qkv, fuse_qkv_params=fuse_qkv_params,
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
scale_attn_logits=scale_attn_logits, scale_attn_logits=scale_attn_logits,
scaled_query_init=scaled_query_init, scaled_query_init=scaled_query_init,
...@@ -768,16 +868,17 @@ class TestMultiHeadAttn(TestLayer): ...@@ -768,16 +868,17 @@ class TestMultiHeadAttn(TestLayer):
flax_MultiHeadAttention, flax_MultiHeadAttention,
dtype=dtype, dtype=dtype,
head_dim=head_dim, head_dim=head_dim,
num_heads=num_heads, num_attention_heads=num_attention_heads,
num_gqa_groups=num_gqa_groups,
layernorm_type=layernorm_type, layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init), kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init),
use_bias=use_bias, use_bias=use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init), bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
apply_residual_connection_post_layernorm=apply_residual_connection_post_layernorm, return_layernorm_output=return_layernorm_output,
output_layernorm=output_layernorm, input_layernorm=input_layernorm,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
fuse_qkv=fuse_qkv, fuse_qkv_params=fuse_qkv_params,
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
scale_attn_logits=scale_attn_logits, scale_attn_logits=scale_attn_logits,
scaled_query_init=scaled_query_init, scaled_query_init=scaled_query_init,
...@@ -1024,6 +1125,7 @@ class TestTransformer(TestLayer): ...@@ -1024,6 +1125,7 @@ class TestTransformer(TestLayer):
enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE] enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE]
enable_relative_embedding = True enable_relative_embedding = True
relative_embedding = pax_fiddle.Config(RelativePositionBiases, relative_embedding = pax_fiddle.Config(RelativePositionBiases,
dtype=dtype,
num_attention_heads=num_attention_heads) num_attention_heads=num_attention_heads)
drop_path = 0.0 drop_path = 0.0
transpose_batch_sequence = attrs[TransformerLayerAttr.TRANSPOSE_BS] transpose_batch_sequence = attrs[TransformerLayerAttr.TRANSPOSE_BS]
......
...@@ -934,7 +934,7 @@ class EncoderLayer(nn.Module): ...@@ -934,7 +934,7 @@ class EncoderLayer(nn.Module):
y = LayerNorm(layernorm_type=self.layernorm_type, y = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype, dtype=self.dtype,
name="output_layer_norm")(y) name="output_layernorm")(y)
return y return y
...@@ -1090,7 +1090,7 @@ class DecoderLayer(nn.Module): ...@@ -1090,7 +1090,7 @@ class DecoderLayer(nn.Module):
z = LayerNorm(layernorm_type=self.layernorm_type, z = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype, dtype=self.dtype,
name="output_layer_norm")(z) name="output_layernorm")(z)
return z return z
......
...@@ -105,8 +105,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -105,8 +105,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
bool flag_m512 = false; bool flag_m512 = false;
bool flag_arb = false; bool flag_arb = false;
if ((sm_arch_ == 80 || sm_arch_ == 90) if ((sm_arch_ == 80 || sm_arch_ == 90)
&& (max_seqlen_q <= 512) && (max_seqlen_q <= 512 && max_seqlen_q % 64 == 0)
&& (max_seqlen_kv <= 512) && (max_seqlen_kv <= 512 && max_seqlen_kv % 64 == 0)
&& (head_dim == 64) && (head_dim == 64)
&& (num_attn_heads == num_gqa_groups) && (num_attn_heads == num_gqa_groups)
&& ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) && ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
......
...@@ -1885,6 +1885,7 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive): ...@@ -1885,6 +1885,7 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
backend = FusedAttnHelper(qkv_dtype, qkv_dtype, NVTE_QKV_Layout.NVTE_BS3HD, attn_bias_type, backend = FusedAttnHelper(qkv_dtype, qkv_dtype, NVTE_QKV_Layout.NVTE_BS3HD, attn_bias_type,
attn_mask_type, dropout_probability, num_heads, num_heads, attn_mask_type, dropout_probability, num_heads, num_heads,
max_seqlen, max_seqlen, head_dim).get_fused_attn_backend() max_seqlen, max_seqlen, head_dim).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen: if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_shape = (*batch_shape, num_heads, max_seqlen, max_seqlen) softmax_shape = (*batch_shape, num_heads, max_seqlen, max_seqlen)
softmax_dtype = qkv_dtype softmax_dtype = qkv_dtype
...@@ -2029,7 +2030,7 @@ def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, seqlen: jnp.ndarray ...@@ -2029,7 +2030,7 @@ def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, seqlen: jnp.ndarray
scaling_factor: float, dropout_probability: float, is_training: bool): scaling_factor: float, dropout_probability: float, is_training: bool):
""" """
Wrapper for TE self fused attention fwd Wrapper for TE self fused attention fwd
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 Return BMM1 -> (PreScaleBias) -> Scale -> (PostScaleBias) -> Softmax -> (Dropout) -> BMM2
""" """
checker = _FusedAttnRNGStateChecker() checker = _FusedAttnRNGStateChecker()
seed = checker.check_seed(seed, dropout_probability, is_training) seed = checker.check_seed(seed, dropout_probability, is_training)
...@@ -2273,6 +2274,7 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive): ...@@ -2273,6 +2274,7 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
attn_bias_type, attn_mask_type, dropout_probability, num_heads, attn_bias_type, attn_mask_type, dropout_probability, num_heads,
num_gqa_groups, q_max_seqlen, kv_max_seqlen, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
q_head_dim).get_fused_attn_backend() q_head_dim).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen: if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_shape = (*q_batch_shape, num_heads, q_max_seqlen, kv_max_seqlen) softmax_shape = (*q_batch_shape, num_heads, q_max_seqlen, kv_max_seqlen)
softmax_dtype = q_dtype softmax_dtype = q_dtype
...@@ -2426,7 +2428,7 @@ def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, q_s ...@@ -2426,7 +2428,7 @@ def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, q_s
dropout_probability: float, is_training: bool): dropout_probability: float, is_training: bool):
""" """
Wrapper for TE cross fused attention fwd Wrapper for TE cross fused attention fwd
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 Return BMM1 -> (PreScaleBias) -> Scale -> (PostScaleBias) -> Softmax -> (Dropout) -> BMM2
""" """
checker = _FusedAttnRNGStateChecker() checker = _FusedAttnRNGStateChecker()
seed = checker.check_seed(seed, dropout_probability, is_training) seed = checker.check_seed(seed, dropout_probability, is_training)
...@@ -2662,6 +2664,445 @@ def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, ...@@ -2662,6 +2664,445 @@ def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
is_training=is_training) is_training=is_training)
class FusedAttnFwdPrimitive(BasePrimitive):
"""
Fused Attention Forward Primitive
Query, key, value are seperated tensors
"""
name = "te_fused_attn_forward"
multiple_results = True
impl_static_args = (7, 8, 9, 10, 11)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(q_aval, k_aval, v_aval, bias_aval, q_seqlen_or_cu_seqlen_aval,
kv_seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
"""
Fused attention fwd abstract
"""
q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
assert q_dtype == k_dtype == v_dtype == bias_dtype
assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
*q_batch_shape, q_max_seqlen, num_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape
assert q_batch_shape == kv_batch_shape
assert q_head_dim == kv_head_dim
assert k_aval.shape == v_aval.shape
out_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
# backend determines the softmax buffer shape/dtype
backend = FusedAttnHelper(q_dtype, k_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD, attn_bias_type,
attn_mask_type, dropout_probability, num_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, q_head_dim).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_shape = (*q_batch_shape, num_heads, q_max_seqlen, kv_max_seqlen)
softmax_dtype = q_dtype
elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
softmax_shape = (*q_batch_shape, num_heads, q_max_seqlen, 1)
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
else:
raise ValueError(f'Unsupported {backend=}')
softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype)
# JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with
# 32-bit unsigned int to get the buffer size we need in the C++ kernel
checker = _FusedAttnRNGStateChecker()
seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
assert seed_dtype == checker.rng_state_dtype
rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)
# do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
# prepare for the active fused-attn backend
batch_size = reduce(operator.mul, q_batch_shape)
wkspace_info = transformer_engine_jax.get_fused_attn_fwd_workspace_sizes(
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, q_head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training)
wkspace_aval = q_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
Fused attention fwd outer primitive abstract
"""
out_aval, softmax_aux_aval, rng_state_aval, _ = \
FusedAttnFwdPrimitive.abstract(*args, **kwargs)
return out_aval, softmax_aux_aval, rng_state_aval
@staticmethod
def lowering(ctx, q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training):
"""
Fused attention fwd lowering rules
"""
operands = [q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
q_aval, k_aval, v_aval, *_ = ctx.avals_in
*batch_shape, q_max_seqlen, num_heads, head_dim = q_aval.shape
*_, kv_max_seqlen, num_gqa_groups, _ = k_aval.shape
assert k_aval.shape == v_aval.shape
batch_size = reduce(operator.mul, batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, head_dim,
wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
return out
@staticmethod
def impl(q, k, v, bias, q_seqlen, kv_seqlen, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
assert FusedAttnFwdPrimitive.inner_primitive is not None
q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind(
q,
k,
v,
bias,
q_cu_seqlen,
kv_cu_seqlen,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output, softmax_aux, rng_state
@staticmethod
def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
_check_valid_batch_dims(batch_dims)
assert FusedAttnFwdPrimitive.outer_primitive is not None
q_bdim, *_, seed_bdim = batch_dims
out_bdims = q_bdim, q_bdim, seed_bdim
return FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training), out_bdims
    @staticmethod
    def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
                                     dropout_probability, is_training, mesh, arg_infos,
                                     result_infos):
        """
        Infer output shardings from the operand shardings.

        `out` follows q; `softmax_aux` is laid out as
        (...batch, head, q_seqlen, kv_seqlen) — built by permuting q's spec
        and taking the kv_seqlen axis from k's spec; `rng_state` has its first
        dim spread over all mesh axes.
        """
        del attn_bias_type, attn_mask_type, scaling_factor
        del dropout_probability, is_training, result_infos
        q_spec = get_padded_spec(arg_infos[0])    # (...batch, q_seqlen, head, hidden)
        k_spec = get_padded_spec(arg_infos[1])    # (...batch, kv_seqlen, head, hidden)
        out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        # (...batch, head, q_seqlen, kv_seqlen)
        softmax_aux_sharding = NamedSharding(
            mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3]))
        rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
        return (out_sharding, softmax_aux_sharding, rng_state_sharding)
    @staticmethod
    def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
                  mesh, arg_infos, result_infos):
        """
        Partitioning rule for the fwd primitive.

        Outputs use the same shardings as `infer_sharding_from_operands`.
        Operands keep their incoming shardings except the seed (last operand),
        which is forced to the rng-state sharding (first dim spread over all
        mesh axes).
        """
        del result_infos
        q_spec = get_padded_spec(arg_infos[0])    # (...batch, q_seqlen, head, hidden)
        k_spec = get_padded_spec(arg_infos[1])    # (...batch, kv_seqlen, head, hidden)
        out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        # (...batch, head, q_seqlen, kv_seqlen)
        softmax_aux_sharding = NamedSharding(
            mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3]))
        rng_state_sharding = seed_sharding = NamedSharding(mesh,
                                                           PartitionSpec(get_all_mesh_axes(), None))
        # Keep operand shardings as-is, but override the trailing seed operand.
        arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding])
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
        impl = partial(FusedAttnFwdPrimitive.impl,
                       attn_bias_type=attn_bias_type,
                       attn_mask_type=attn_mask_type,
                       scaling_factor=scaling_factor,
                       dropout_probability=dropout_probability,
                       is_training=is_training)
        return mesh, impl, out_shardings, arg_shardings
# Hook FusedAttnFwdPrimitive's abstract/lowering/batching/partition rules into JAX.
register_primitive(FusedAttnFwdPrimitive)
def fused_attn_fwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray,
                   q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray, seed: jnp.ndarray,
                   attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
                   scaling_factor: float, dropout_probability: float, is_training: bool):
    """
    Wrapper for TE fused attention fwd, where query, key, value are separated tensors.

    Computes BMM1 -> (PreScaleBias) -> Scale -> (PostScaleBias) -> Softmax ->
    (Dropout) -> BMM2 and returns (output, softmax_aux, rng_state).

    Parameters
    ----------
    q, k, v:
        Separate query/key/value tensors
        (assumed (...batch, seqlen, heads, head_dim) — see the bwd abstract).
    bias:
        Attention bias tensor; must be None when `attn_bias_type` is
        NVTE_NO_BIAS (a dummy empty tensor is substituted internally).
    q_seqlen, kv_seqlen:
        Per-batch-element actual sequence lengths.
    seed:
        Dropout RNG seed; validated/normalized by _FusedAttnRNGStateChecker.
    """
    checker = _FusedAttnRNGStateChecker()
    seed = checker.check_seed(seed, dropout_probability, is_training)

    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
        assert bias is None, "bias must be None when attn_bias_type is NVTE_NO_BIAS"
        # The custom call always takes a bias operand, so pass an empty tensor.
        bias = jnp.zeros(0, dtype=q.dtype)

    return FusedAttnFwdPrimitive.outer_primitive.bind(q,
                                                      k,
                                                      v,
                                                      bias,
                                                      q_seqlen,
                                                      kv_seqlen,
                                                      seed,
                                                      attn_bias_type=attn_bias_type,
                                                      attn_mask_type=attn_mask_type,
                                                      scaling_factor=scaling_factor,
                                                      dropout_probability=dropout_probability,
                                                      is_training=is_training)
class FusedAttnBwdPrimitive(BasePrimitive):
    """
    Fused Attention Backward Primitive

    Computes (dq, dk, dv, dbias) for separate-QKV fused attention, consuming
    the softmax statistics and RNG state saved by the forward primitive.
    """
    name = "te_fused_attn_backward"
    multiple_results = True
    # Indices of the static (non-tensor) args of `impl`: attn_bias_type,
    # attn_mask_type, scaling_factor, dropout_probability, is_training.
    impl_static_args = (10, 11, 12, 13, 14)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(q_aval, k_aval, v_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval,
                 doutput_aval, q_cu_seqlen_aval, kv_cu_seqlen_aval, *, attn_bias_type,
                 attn_mask_type, scaling_factor, dropout_probability, is_training):
        """
        Fused attention bwd abstract.

        Validates dtypes/shapes and returns the gradient avals plus a trailing
        workspace aval sized by the C++ side.
        """
        del softmax_aux_aval, rng_state_aval, output_aval
        q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
        k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
        v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
        doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype)
        assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype
        assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype

        # q: (...batch, q_seqlen, num_heads, head_dim)
        # k/v: (...batch, kv_seqlen, num_gqa_groups, head_dim)
        *q_batch_shape, q_max_seqlen, num_heads, q_head_dim = q_aval.shape
        *kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape
        assert q_batch_shape == kv_batch_shape
        assert q_head_dim == kv_head_dim
        assert k_aval.shape == v_aval.shape

        batch_size = reduce(operator.mul, q_batch_shape)
        # Ask the C++ extension for the cuDNN workspace requirement.
        wkspace_shape, wkspace_dtype = \
            transformer_engine_jax.get_fused_attn_bwd_workspace_sizes(
                batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, q_head_dim,
                scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
                jax_dtype_to_te_dtype(q_aval.dtype), is_training
            )

        # Each gradient mirrors its input's shape and dtype.
        dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
        dk_aval = k_aval.update(shape=k_aval.shape, dtype=k_dtype)
        dv_aval = v_aval.update(shape=v_aval.shape, dtype=v_dtype)
        dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
        wkspace_aval = q_aval.update(shape=wkspace_shape,
                                     dtype=te_dtype_to_jax_dtype(wkspace_dtype))

        return dq_aval, dk_aval, dv_aval, dbias_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        Fused attention bwd outer primitive abstract.
        Same as `abstract`, minus the trailing workspace aval.
        """
        dq_aval, dk_aval, dv_aval, dbias_aval, _ = \
            FusedAttnBwdPrimitive.abstract(*args, **kwargs)
        return dq_aval, dk_aval, dv_aval, dbias_aval

    @staticmethod
    def lowering(ctx, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen,
                 kv_cu_seqlen, *, attn_bias_type, attn_mask_type, scaling_factor,
                 dropout_probability, is_training):
        """
        Fused attention bwd lowering rules.

        The operand order must match the buffer order read by the C++
        `FusedAttnBackward` custom-call target.
        """
        operands = [
            q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen
        ]
        operand_shapes = map(lambda x: x.type.shape, operands)
        out_types = [
            ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
            for output in ctx.avals_out
        ]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        q_aval, k_aval, v_aval, *_ = ctx.avals_in
        *batch_shape, q_max_seqlen, num_heads, head_dim = q_aval.shape
        *_, kv_max_seqlen, num_gqa_groups, _ = k_aval.shape
        assert k_aval.shape == v_aval.shape
        batch_size = reduce(operator.mul, batch_shape)
        # The workspace aval is appended as the last output by `abstract`.
        wkspace_aval = ctx.avals_out[-1]
        opaque = transformer_engine_jax.pack_fused_attn_descriptor(
            batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, head_dim,
            wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
            jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
            is_training)
        out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
        return out

    @staticmethod
    def impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen,
             attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training):
        """
        Fused attention bwd implementation: convert per-sequence lengths into
        cumulative offsets, invoke the inner primitive, drop the workspace.
        """
        assert FusedAttnBwdPrimitive.inner_primitive is not None
        q_cu_seqlen = generate_cu_seqlen(q_seqlen)
        kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
        dq, dk, dv, dbias, _ = FusedAttnBwdPrimitive.inner_primitive.bind(
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)
        return dq, dk, dv, dbias

    @staticmethod
    def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
                dropout_probability, is_training):
        """
        Batching rule: dq/dk/dv follow q/k/v's batch dims; dbias follows q's.
        """
        _check_valid_batch_dims(batch_dims)
        assert FusedAttnBwdPrimitive.outer_primitive is not None
        q_bdim, k_bdim, v_bdim, *_ = batch_dims
        out_bdims = q_bdim, k_bdim, v_bdim, q_bdim
        return FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args,
                                                          attn_bias_type=attn_bias_type,
                                                          attn_mask_type=attn_mask_type,
                                                          scaling_factor=scaling_factor,
                                                          dropout_probability=dropout_probability,
                                                          is_training=is_training), out_bdims

    @staticmethod
    def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
                                     dropout_probability, is_training, mesh, arg_infos,
                                     result_infos):
        """
        Infer output shardings: each gradient inherits the sharding of the
        tensor it corresponds to (dq<-q, dk<-k, dv<-v, dbias<-bias).
        """
        del attn_bias_type, attn_mask_type, scaling_factor
        del dropout_probability, is_training, result_infos
        q_spec = get_padded_spec(arg_infos[0])
        k_spec = get_padded_spec(arg_infos[1])
        v_spec = get_padded_spec(arg_infos[2])
        bias_spec = get_padded_spec(arg_infos[3])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
        dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)

    @staticmethod
    def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
                  mesh, arg_infos, result_infos):
        """
        Partitioning rule for the bwd primitive.

        Local gradients come from `impl`; dbias is additionally sum-reduced
        across dp/fsdp ranks (presumably because bias is replicated along
        those axes — see all_reduce_sum_along_dp_fsdp).
        """
        del result_infos
        q_spec = get_padded_spec(arg_infos[0])
        k_spec = get_padded_spec(arg_infos[1])
        v_spec = get_padded_spec(arg_infos[2])
        bias_spec = get_padded_spec(arg_infos[3])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
        dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)

        # NOTE(review): despite the names, q_cu_seqlen/kv_cu_seqlen here are
        # per-sequence lengths — `impl` applies generate_cu_seqlen to them.
        def sharded_impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen,
                         kv_cu_seqlen):
            local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl(
                q,
                k,
                v,
                bias,
                softmax_aux,
                rng_state,
                output,
                doutput,
                q_cu_seqlen,
                kv_cu_seqlen,
                attn_bias_type=attn_bias_type,
                attn_mask_type=attn_mask_type,
                scaling_factor=scaling_factor,
                dropout_probability=dropout_probability,
                is_training=is_training)
            global_dbias = local_dbias
            if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
                # Accumulate the bias gradient across dp/fsdp ranks.
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
            return local_dq, local_dk, local_dv, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings
# Hook FusedAttnBwdPrimitive's abstract/lowering/batching/partition rules into JAX.
register_primitive(FusedAttnBwdPrimitive)
def fused_attn_bwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray,
                   softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray,
                   doutput: jnp.ndarray, q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray,
                   attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
                   scaling_factor: float, dropout_probability: float, is_training: bool):
    """
    Wrapper for TE fused attention bwd.

    Returns the gradients (dq, dk, dv, dbias) of fused attention with
    separated query, key, value tensors.

    Parameters
    ----------
    q, k, v:
        Separate query/key/value tensors from the forward pass.
    bias:
        Attention bias tensor; must be None when `attn_bias_type` is
        NVTE_NO_BIAS (a dummy empty tensor is substituted internally).
    softmax_aux, rng_state:
        Auxiliary outputs saved by `fused_attn_fwd`.
    output, doutput:
        Forward output and its incoming gradient.
    q_seqlen, kv_seqlen:
        Per-batch-element actual sequence lengths.
    """
    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
        assert bias is None, "bias must be None when attn_bias_type is NVTE_NO_BIAS"
        # The custom call always takes a bias operand, so pass an empty tensor.
        bias = jnp.zeros(0, dtype=q.dtype)
    return FusedAttnBwdPrimitive.outer_primitive.bind(q,
                                                      k,
                                                      v,
                                                      bias,
                                                      softmax_aux,
                                                      rng_state,
                                                      output,
                                                      doutput,
                                                      q_seqlen,
                                                      kv_seqlen,
                                                      attn_bias_type=attn_bias_type,
                                                      attn_mask_type=attn_mask_type,
                                                      scaling_factor=scaling_factor,
                                                      dropout_probability=dropout_probability,
                                                      is_training=is_training)
class GeluPrimitive(BasePrimitive): class GeluPrimitive(BasePrimitive):
""" """
Gelu Froward Primitive Gelu Froward Primitive
......
...@@ -53,6 +53,8 @@ pybind11::dict Registrations() { ...@@ -53,6 +53,8 @@ pybind11::dict Registrations() {
dict["te_self_fused_attn_backward"] = EncapsulateFunction(SelfFusedAttnBackward); dict["te_self_fused_attn_backward"] = EncapsulateFunction(SelfFusedAttnBackward);
dict["te_cross_fused_attn_forward"] = EncapsulateFunction(CrossFusedAttnForward); dict["te_cross_fused_attn_forward"] = EncapsulateFunction(CrossFusedAttnForward);
dict["te_cross_fused_attn_backward"] = EncapsulateFunction(CrossFusedAttnBackward); dict["te_cross_fused_attn_backward"] = EncapsulateFunction(CrossFusedAttnBackward);
dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward);
dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward);
return dict; return dict;
} }
...@@ -74,6 +76,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) { ...@@ -74,6 +76,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_self_fused_attn_bwd_workspace_sizes", &GetSelfFusedAttnBackwardWorkspaceSizes); m.def("get_self_fused_attn_bwd_workspace_sizes", &GetSelfFusedAttnBackwardWorkspaceSizes);
m.def("get_cross_fused_attn_fwd_workspace_sizes", &GetCrossFusedAttnForwardWorkspaceSizes); m.def("get_cross_fused_attn_fwd_workspace_sizes", &GetCrossFusedAttnForwardWorkspaceSizes);
m.def("get_cross_fused_attn_bwd_workspace_sizes", &GetCrossFusedAttnBackwardWorkspaceSizes); m.def("get_cross_fused_attn_bwd_workspace_sizes", &GetCrossFusedAttnBackwardWorkspaceSizes);
m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes);
m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
pybind11::enum_<DType>(m, "DType", pybind11::module_local()) pybind11::enum_<DType>(m, "DType", pybind11::module_local())
.value("kByte", DType::kByte) .value("kByte", DType::kByte)
...@@ -98,7 +102,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) { ...@@ -98,7 +102,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local()) pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local())
.value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD)
.value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD); .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
.value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD);
pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local())
.value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend) .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend)
......
...@@ -1253,7 +1253,6 @@ pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes( ...@@ -1253,7 +1253,6 @@ pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype); auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
// TODO(rewang): add bias for cross attn?
auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype); auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
// FP16/BF16 doesn't use this tensor // FP16/BF16 doesn't use this tensor
...@@ -1488,5 +1487,265 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa ...@@ -1488,5 +1487,265 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
nvte_tensor_pack_destroy(&aux_input_tensors); nvte_tensor_pack_destroy(&aux_input_tensors);
} }
// Workspace-size query for the separate-QKV (BSHD_BSHD_BSHD) fused attention
// forward pass.  nvte_fused_attn_fwd is called with null data pointers and a
// null stream; in this mode it is expected to only populate
// query_workspace_tensor's shape/dtype rather than execute (same pattern as
// the other Get*FusedAttnForwardWorkspaceSizes helpers).
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
    size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
    size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
  constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;

  // Batch and sequence dims are flattened: (b * s, h, d).
  auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
  auto k_shape = std::vector<size_t>{batch_size * kv_max_seqlen, num_gqa_groups, head_dim};
  auto v_shape = k_shape;
  auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};

  auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
  auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
  auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
  auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);

  // F16 doesn't use this tensor
  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
  auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);

  auto q_cu_seqlens_tensor =
      TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
  auto kv_cu_seqlens_tensor =
      TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);

  auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);

  // NOTE(review): aux_output_tensors is created but never destroyed here,
  // unlike FusedAttnForward below — confirm whether nvte_tensor_pack_create
  // allocates resources that would leak.
  NVTETensorPack aux_output_tensors;
  nvte_tensor_pack_create(&aux_output_tensors);

  TensorWrapper query_workspace_tensor;
  nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
                      s_tensor.data(), o_tensor.data(), &aux_output_tensors,
                      q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
                      dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
                      scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
                      query_workspace_tensor.data(), nullptr);

  auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
  return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}
// XLA custom-call target for the separate-QKV fused attention forward pass.
// The `buffers` layout must match the operand/result order emitted by the
// Python side (FusedAttnFwdPrimitive.lowering): inputs
// (q, k, v, bias, q_cu_seqlens, kv_cu_seqlens, seed) followed by outputs
// (output, softmax_aux, rng_state, workspace).
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  const CustomCallFusedAttnDescriptor &descriptor =
      *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

  // input buffers from XLA
  void *q = buffers[0];
  void *k = buffers[1];
  void *v = buffers[2];
  void *bias = buffers[3];
  void *q_cu_seqlens = buffers[4];
  void *kv_cu_seqlens = buffers[5];
  void *seed = buffers[6];

  // output buffers from XLA
  void *output = buffers[7];
  void *softmax_aux = buffers[8];
  void *rng_state = buffers[9];
  void *workspace = buffers[10];

  // tensor sizes
  auto batch_size = descriptor.batch_size;
  auto q_max_seqlen = descriptor.q_max_seqlen;
  auto kv_max_seqlen = descriptor.kv_max_seqlen;
  auto num_heads = descriptor.num_heads;
  auto num_gqa_groups = descriptor.num_gqa_groups;
  auto head_dim = descriptor.head_dim;
  auto scaling_factor = descriptor.scaling_factor;
  auto dropout_probability = descriptor.dropout_probability;
  auto bias_type = descriptor.bias_type;
  auto mask_type = descriptor.mask_type;

  // Batch and sequence dims are flattened: (b * s, h, d).
  auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
  auto k_shape = std::vector<size_t>{batch_size * kv_max_seqlen, num_gqa_groups, head_dim};
  auto v_shape = k_shape;
  auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};

  // input tensors
  auto dtype = descriptor.dtype;
  auto q_tensor = TensorWrapper(q, q_shape, dtype);
  auto k_tensor = TensorWrapper(k, k_shape, dtype);
  auto v_tensor = TensorWrapper(v, v_shape, dtype);
  auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);

  // output tensors
  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);    // not used in F16
  auto o_tensor = TensorWrapper(output, q_shape, dtype);
  auto q_cu_seqlens_tensor =
      TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
  auto kv_cu_seqlens_tensor =
      TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);

  // prep RNG state
  constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;
  auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
  // The backend choice feeds PopulateRngStateAsync and the aux-tensor layout.
  auto backend = nvte_get_fused_attn_backend(
      static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
      mask_type, dropout_probability, num_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
      head_dim);
  PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);

  // auxiliary tensors (to be propagated to the backward pass later)
  NVTETensorPack aux_output_tensors;
  nvte_tensor_pack_create(&aux_output_tensors);
  PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend,
                                    softmax_aux);

  // cuDNN workspace
  auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size},
                                        descriptor.wkspace_dtype);

  nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
                      s_tensor.data(), o_tensor.data(), &aux_output_tensors,
                      q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
                      rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
                      descriptor.is_training, scaling_factor, dropout_probability, qkv_layout,
                      bias_type, mask_type, workspace_tensor.data(), stream);

  nvte_tensor_pack_destroy(&aux_output_tensors);
}
// Workspace-size query for the separate-QKV fused attention backward pass.
// Same null-pointer query pattern as GetFusedAttnForwardWorkspaceSizes.
// Note: `is_training` is accepted for signature symmetry with the forward
// query but is not passed to nvte_fused_attn_bwd.
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
    size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
  constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;

  // Batch and sequence dims are flattened: (b * s, h, d).
  auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
  auto k_shape = std::vector<size_t>{batch_size * kv_max_seqlen, num_gqa_groups, head_dim};
  auto v_shape = k_shape;
  auto output_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
  auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};

  auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
  auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
  auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
  auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
  auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
  // F16 doesn't use this tensor
  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);

  auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
  auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
  auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype);
  auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);

  auto q_cu_seqlens_tensor =
      TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
  auto kv_cu_seqlens_tensor =
      TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);

  // NOTE(review): aux_input_tensors is created but never destroyed here,
  // unlike FusedAttnBackward below — confirm whether nvte_tensor_pack_create
  // allocates resources that would leak.
  NVTETensorPack aux_input_tensors;
  nvte_tensor_pack_create(&aux_input_tensors);

  TensorWrapper query_workspace_tensor;
  nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
                      doutput_tensor.data(),
                      s_tensor.data(),    // not used for F16
                      s_tensor.data(),    // not used for F16
                      &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
                      dbias_tensor.data(), q_cu_seqlens_tensor.data(),
                      kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
                      dropout_probability, qkv_layout, bias_type, mask_type,
                      query_workspace_tensor.data(), nullptr);

  auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
  return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}
// XLA custom-call target for the separate-QKV fused attention backward pass.
// The `buffers` layout must match the operand/result order emitted by the
// Python side (FusedAttnBwdPrimitive.lowering): inputs
// (q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlens,
// kv_cu_seqlens) followed by outputs (dq, dk, dv, dbias, workspace).
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  const CustomCallFusedAttnDescriptor &descriptor =
      *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

  // input buffers from XLA
  void *q = buffers[0];
  void *k = buffers[1];
  void *v = buffers[2];
  void *bias = buffers[3];
  void *softmax_aux = buffers[4];
  void *rng_state = buffers[5];
  void *output = buffers[6];
  void *doutput = buffers[7];
  void *q_cu_seqlens = buffers[8];
  void *kv_cu_seqlens = buffers[9];

  // output buffers from XLA
  void *dq = buffers[10];
  void *dk = buffers[11];
  void *dv = buffers[12];
  void *dbias = buffers[13];
  void *workspace = buffers[14];

  // tensor sizes
  auto batch_size = descriptor.batch_size;
  auto q_max_seqlen = descriptor.q_max_seqlen;
  auto kv_max_seqlen = descriptor.kv_max_seqlen;
  auto num_heads = descriptor.num_heads;
  auto num_gqa_groups = descriptor.num_gqa_groups;
  auto head_dim = descriptor.head_dim;
  auto scaling_factor = descriptor.scaling_factor;
  auto dropout_probability = descriptor.dropout_probability;
  auto bias_type = descriptor.bias_type;
  auto mask_type = descriptor.mask_type;

  // Batch and sequence dims are flattened: (b * s, h, d).
  auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
  auto k_shape = std::vector<size_t>{batch_size * kv_max_seqlen, num_gqa_groups, head_dim};
  auto v_shape = k_shape;
  auto output_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
  auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};

  // input tensors
  auto dtype = descriptor.dtype;
  auto q_tensor = TensorWrapper(q, q_shape, dtype);
  auto k_tensor = TensorWrapper(k, k_shape, dtype);
  auto v_tensor = TensorWrapper(v, v_shape, dtype);
  auto output_tensor = TensorWrapper(output, output_shape, dtype);
  auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);

  // output tensors
  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);    // not used in F16
  auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
  auto dk_tensor = TensorWrapper(dk, k_shape, dtype);
  auto dv_tensor = TensorWrapper(dv, v_shape, dtype);
  auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
  auto q_cu_seqlens_tensor =
      TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
  auto kv_cu_seqlens_tensor =
      TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);

  // auxiliary tensors (propagated from the forward pass)
  NVTETensorPack aux_input_tensors;
  nvte_tensor_pack_create(&aux_input_tensors);
  constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;
  // The backend choice determines how the forward aux tensors are unpacked.
  auto backend = nvte_get_fused_attn_backend(
      static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
      mask_type, dropout_probability, num_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
      head_dim);
  PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
                                     rng_state, bias);

  // cuDNN workspace
  auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
  auto wkspace_dtype = descriptor.wkspace_dtype;
  auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);

  nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
                      doutput_tensor.data(),
                      s_tensor.data(),    // not used for F16
                      s_tensor.data(),    // not used for F16
                      &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
                      dbias_tensor.data(), q_cu_seqlens_tensor.data(),
                      kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
                      dropout_probability, qkv_layout, bias_type, mask_type,
                      workspace_tensor.data(), stream);

  nvte_tensor_pack_destroy(&aux_input_tensors);
}
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
...@@ -236,6 +236,20 @@ pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes( ...@@ -236,6 +236,20 @@ pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
......
...@@ -5,11 +5,19 @@ ...@@ -5,11 +5,19 @@
from .module import DenseGeneral, LayerNorm from .module import DenseGeneral, LayerNorm
from .module import LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase from .module import LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase
from .transformer import extend_logical_axis_rules from .transformer import extend_logical_axis_rules
from .transformer import MultiHeadAttention, RelativePositionBiases from .transformer import DotProductAttention, MultiHeadAttention, RelativePositionBiases
from .transformer import TransformerLayer, TransformerLayerType from .transformer import TransformerLayer, TransformerLayerType
__all__ = [ __all__ = [
'DenseGeneral', 'LayerNorm', 'LayerNormDenseGeneral', 'LayerNormMLP', 'DenseGeneral',
'TransformerEngineBase', 'extend_logical_axis_rules', 'MultiHeadAttention', 'LayerNorm',
'RelativePositionBiases', 'TransformerLayer', 'TransformerLayerType', 'LayerNormDenseGeneral',
'LayerNormMLP',
'TransformerEngineBase',
'extend_logical_axis_rules',
'DotProductAttention',
'MultiHeadAttention',
'RelativePositionBiases',
'TransformerLayer',
'TransformerLayerType',
] ]
...@@ -16,6 +16,7 @@ import jax.numpy as jnp ...@@ -16,6 +16,7 @@ import jax.numpy as jnp
import numpy as np import numpy as np
from flax import linen as nn from flax import linen as nn
from flax.linen import partitioning as nn_partitioning from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import combine_masks
from jax import nn as jax_nn from jax import nn as jax_nn
from jax import random as jax_random from jax import random as jax_random
from jax import lax, vmap from jax import lax, vmap
...@@ -24,8 +25,8 @@ from jax.ad_checkpoint import checkpoint_name ...@@ -24,8 +25,8 @@ from jax.ad_checkpoint import checkpoint_name
from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax from .module import LayerNorm, Softmax
from ..fused_attn import AttnBiasType, AttnMaskType, QKVLayout from ..fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from ..fused_attn import is_fused_attn_kernel_available from ..fused_attn import is_fused_attn_kernel_available, canonicalize_attn_mask_type
from ..fused_attn import self_fused_attn, cross_fused_attn from ..fused_attn import self_fused_attn, cross_fused_attn, fused_attn
from ..softmax import SoftmaxType from ..softmax import SoftmaxType
from ..sharding import num_of_devices from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
...@@ -71,12 +72,12 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: ...@@ -71,12 +72,12 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
Parameters Parameters
---------- ----------
rules : Sequence[Tuple[str, Union[str, None]]] rules: Sequence[Tuple[str, Union[str, None]]]
the base Flax logical axis rules to extend. the base Flax logical axis rules to extend.
Returns Returns
------- -------
extended_rules : Sequence[Tuple[str, Union[str, None]]] extended_rules: Sequence[Tuple[str, Union[str, None]]]
the extended Flax logical axis rules. the extended Flax logical axis rules.
""" """
rules_map = {} rules_map = {}
...@@ -108,60 +109,43 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: ...@@ -108,60 +109,43 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
return tuple(extended_rules) return tuple(extended_rules)
def _merge_mask(func, *masks: Optional[Array]): class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
masks = [m for m in masks if m is not None] attention_dropout: float = 0.
if not masks: attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
return None attn_bias_type: Optional[AttnBiasType] = None
assert all(map(lambda x: x.ndim == masks[0].ndim, dtype: DType = jnp.float32
masks)), (f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}') float32_logits: bool = False
mask, *other_masks = masks scale_factor: Optional[float] = None
for other_mask in other_masks: transpose_batch_sequence: bool = True
mask = func(mask, other_mask)
return mask
def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32):
"""Combine attention masks."""
func = jnp.logical_and
return _merge_mask(func, *masks).astype(dtype)
def combine_biases(*masks: Optional[Array]):
"""Combine attention biases."""
def func(a, b):
return a + b
return _merge_mask(func, *masks)
def core_attention(query: Array, @nn.compact
def __call__(self,
query: Array,
key: Array, key: Array,
value: Array, value: Array,
scale_factor: float,
transpose_batch_sequence: bool,
softmax_type: SoftmaxType = SoftmaxType.SCALED,
mask: Optional[Array] = None, mask: Optional[Array] = None,
bias: Optional[Array] = None, bias: Optional[Array] = None,
*,
dropout_rng: Optional[PRNGKey] = None, dropout_rng: Optional[PRNGKey] = None,
dropout_rate: float = 0., deterministic: bool = False) -> Array:
deterministic: bool = False,
dtype: DType = jnp.float32,
float32_logits: bool = False):
"""Core attention"""
assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
batch_dim = 1 if transpose_batch_sequence else 0 batch_dim = 1 if self.transpose_batch_sequence else 0
assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], ( assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], (
'q, k, v batch dims must match.') 'q, k, v batch dims must match.')
sequence_dim = 0 if transpose_batch_sequence else 1 sequence_dim = 0 if self.transpose_batch_sequence else 1
assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.' assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.'
assert key.shape[-2] == value.shape[-2], 'k, v num_heads must match.' assert key.shape[-2] == value.shape[-2], 'k, v num_attention_heads must match.'
assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.' assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.'
if float32_logits: if self.scale_factor is None:
scale_factor = 1.0 / sqrt(query.shape[-1])
else:
scale_factor = self.scale_factor
del self.scale_factor
if self.float32_logits:
query = query.astype(jnp.float32) query = query.astype(jnp.float32)
key = key.astype(jnp.float32) key = key.astype(jnp.float32)
h_q, h_kv = query.shape[-2], key.shape[-2] h_q, h_kv = query.shape[-2], key.shape[-2]
# The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv. # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
# Therefore, we have to maintain two code paths. # Therefore, we have to maintain two code paths.
...@@ -172,7 +156,7 @@ def core_attention(query: Array, ...@@ -172,7 +156,7 @@ def core_attention(query: Array,
group_size = h_q // h_kv group_size = h_q // h_kv
grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1])) grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))
if transpose_batch_sequence: if self.transpose_batch_sequence:
if is_gqa: if is_gqa:
attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key) attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key)
else: else:
...@@ -193,30 +177,47 @@ def core_attention(query: Array, ...@@ -193,30 +177,47 @@ def core_attention(query: Array,
attn_weights = with_sharding_constraint_by_logical_axes( attn_weights = with_sharding_constraint_by_logical_axes(
attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES)) attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES))
# When a bias is present, the computation is performed as Softmax(attn_weights * scale + bias). # When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias)
# In this case, the scale can not fused into the Softmax module. # In this case, the scale can not fused into the Softmax module.
if bias is not None: if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
attn_weights = attn_weights * scale_factor attn_weights = attn_weights * scale_factor
fused_scale_factor = 1. fused_scale_factor = 1.
else: else:
# If no bias, the scale can be fused into Softmax module # If not post_scale_bias, the scale can be fused into Softmax module
fused_scale_factor = scale_factor fused_scale_factor = scale_factor
if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
attn_weights += bias
def convert_to_softmax_type(attn_mask_type, mask):
"""Convert the attn_mask_type to SoftmaxType"""
if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
return SoftmaxType.SCALED_UPPER_TRIANG_MASKED
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]:
if mask is not None:
return SoftmaxType.SCALED_MASKED
return SoftmaxType.SCALED
raise ValueError(f"Unsupported {attn_mask_type=}, "
"supported attn_mask_type = {'causal', 'padding'}")
softmax_type = convert_to_softmax_type(self.attn_mask_type, mask)
attn_weights = Softmax(softmax_type=softmax_type, attn_weights = Softmax(softmax_type=softmax_type,
scale_factor=fused_scale_factor)(attn_weights, mask, bias).astype(dtype) scale_factor=fused_scale_factor)(attn_weights, mask,
bias).astype(self.dtype)
if is_gqa: if is_gqa:
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape) attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
if not deterministic and dropout_rate > 0.: if not deterministic and self.attention_dropout > 0.:
keep_prob = 1.0 - dropout_rate keep_prob = 1.0 - self.attention_dropout
dropout_shape = list(attn_weights.shape) dropout_shape = list(attn_weights.shape)
# TODO(rewang): add attention dropout broadcast dimension arguments for users # TODO(rewang): add attention dropout broadcast dimension arguments for users
keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape) keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
multiplier = (keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype)) multiplier = (keep.astype(attn_weights.dtype) /
jnp.asarray(keep_prob, dtype=self.dtype))
attn_weights = attn_weights * multiplier attn_weights = attn_weights * multiplier
if transpose_batch_sequence: if self.transpose_batch_sequence:
if is_gqa: if is_gqa:
return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape) return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape)
return jnp.einsum('bhqk,kbhd->qbhd', attn_weights, value) return jnp.einsum('bhqk,kbhd->qbhd', attn_weights, value)
...@@ -226,6 +227,320 @@ def core_attention(query: Array, ...@@ -226,6 +227,320 @@ def core_attention(query: Array,
return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value) return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)
class _FusedDotProductAttention(nn.Module):    # pylint: disable=too-few-public-methods
    """Dot product attention backed by the cuDNN fused attention kernels.

    Dispatches to self_fused_attn / cross_fused_attn / fused_attn based on
    ``qkv_layout``. The fused kernels consume (batch, seqlen, ...) tensors, so
    inputs are transposed on the way in (and the output on the way out) when
    ``transpose_batch_sequence`` is True.
    """
    # Dropout probability applied inside the fused kernel (after softmax).
    attention_dropout: float = 0.
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
    dtype: DType = jnp.float32
    # How query/key/value arguments of __call__ are interpreted (packed or separate).
    qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD
    # None means the default 1/sqrt(head_dim) scaling is computed in __call__.
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = False

    @nn.compact
    def __call__(self,
                 query: Array,
                 key: Array,
                 value: Array,
                 mask: Optional[Array] = None,
                 bias: Optional[Array] = None,
                 *,
                 dropout_rng: Optional[PRNGKey] = None,
                 deterministic: bool = False) -> Array:
        """Run fused attention.

        The meaning of query/key/value depends on qkv_layout; see the branch
        comments below. Returns the attention output tensor.
        """
        seed = None
        if dropout_rng is not None:
            # One seed per device: the fused kernel generates its dropout mask on-device.
            seed = jax.random.split(dropout_rng, num_of_devices())

        if self.scale_factor is None:
            # Default attention scaling: 1/sqrt(d_k).
            scale_factor = 1.0 / sqrt(query.shape[-1])
        else:
            scale_factor = self.scale_factor
        # NOTE(review): deleting a dataclass field from a flax Module instance is
        # unusual — presumably to force use of the local `scale_factor` from here
        # on; confirm this is intended and supported by flax.
        del self.scale_factor

        if self.qkv_layout == QKVLayout.BS3HD:
            """qkvpacked format, treat
            query: qkvpacked tensor, shape = [..., 3, h, d]
            key: ignore
            value: ignore
            """
            qkv_packed = query
            if self.transpose_batch_sequence:
                # (s, b, 3, h, d) -> (b, s, 3, h, d) as required by the kernel.
                qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4])
            x = self_fused_attn(qkv_packed,
                                bias,
                                mask,
                                seed,
                                attn_mask_type=self.attn_mask_type,
                                attn_bias_type=self.attn_bias_type,
                                scaling_factor=scale_factor,
                                dropout_probability=self.attention_dropout,
                                is_training=not deterministic)
        elif self.qkv_layout == QKVLayout.BSHD_BS2HD:
            """kvpacked format, treat
            query: query tensor, shape = [..., h, d]
            key: kvpacked tensor, shape = [..., 2, h, d]
            value: ignore
            """
            kv_packed = key
            if self.transpose_batch_sequence:
                query = query.transpose([1, 0, 2, 3])
                kv_packed = kv_packed.transpose([1, 0, 2, 3, 4])
            x = cross_fused_attn(query,
                                 kv_packed,
                                 bias,
                                 mask,
                                 seed,
                                 attn_mask_type=self.attn_mask_type,
                                 attn_bias_type=self.attn_bias_type,
                                 scaling_factor=scale_factor,
                                 dropout_probability=self.attention_dropout,
                                 is_training=not deterministic)
        elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
            # Separate query/key/value tensors, each shape = [..., h, d].
            if self.transpose_batch_sequence:
                query = query.transpose([1, 0, 2, 3])
                key = key.transpose([1, 0, 2, 3])
                value = value.transpose([1, 0, 2, 3])
            x = fused_attn(query,
                           key,
                           value,
                           bias,
                           mask,
                           seed,
                           attn_mask_type=self.attn_mask_type,
                           attn_bias_type=self.attn_bias_type,
                           scaling_factor=scale_factor,
                           dropout_probability=self.attention_dropout,
                           is_training=not deterministic)
        else:
            raise ValueError(f"Unsupported {self.qkv_layout=}.")

        if self.transpose_batch_sequence:
            # Restore the caller's (seqlen, batch, ...) layout.
            x = x.transpose([1, 0, 2, 3])
        return x
class DotProductAttention(nn.Module):    # pylint: disable=too-few-public-methods
    r"""
    Dot Product Attention (DPA). Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::

        The DotProductAttention module supports two backends: the unfused and the fused attention
        mechanisms. The unfused attention is implemented using JAX native operations, providing
        broad compatibility and flexibility. In contrast, the fused attention uses `cuDNN fused
        attention
        <https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md>`_ for
        higher performance and lower memory usage on the supported hardware.
        Users can select between these two backends via the :attr:`NVTE_FUSED_ATTN` environment
        variable:

        * Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention (default).
        * Set :attr:`NVTE_FUSED_ATTN=1` for fused attention. If the required cuDNN fused attention
          kernel is not available on the system, a warning will be issued, and the module will
          automatically fall back to the unfused backend.

    Parameters
    ----------
    head_dim: int
        The hidden dimension of each attention head.
    num_attention_heads: int
        The number of attention heads.
    num_gqa_groups: Optional[int], default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
        Type of the attention mask passed into softmax operation in the self attention.
        Available options: {'no_mask', 'padding', 'causal', 'causal_padding'}
        Introduced in v0.10.0.
    attn_bias_type: Optional[str], default = None
        Type of the attention bias passed in the self attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is :attr:`post_scale_bias` if there is bias. Otherwise :attr:`no_bias` is used.
    dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    qkv_layout: str, default = 'bshd_bshd_bshd'
        Specifies the dimensional layout format for the query, key, and value tensors in
        __call__(). It indicates how the inputs are processed.
        Available options: {'bs3hd', 'bshd_bs2hd', 'bshd_bshd_bshd'}. Where

        * bs3hd: query tensor is treated as a qkvpacked tensor with shape = [b, s, 3, h, d].
          key and value arguments in :attr:`__call__()` are ignored in this layout.
        * bshd_bs2hd: query tensor with shape = [b, s, h, d]. key tensor is treated as a kvpacked
          tensor with shape = [b, s, 2, h, d]. `value` argument in :attr:`__call__()` is ignored.
        * bshd_bshd_bshd: query, key, and value are separated with shape = [b, s, h, d].

        Explanation of denotations:

        * b: batch size
        * s: sequence length
        * h: num_attention_heads or num_gqa_groups
        * d: head dimension

    scale_factor: Optional[float], default = None
        Scale factor to apply on query. When :attr:`None` is present, the scale factor is equal
        to :math:`\frac{1}{\sqrt{head\_dim}}`. This is useful for models like T5X, which do not
        apply a scale on the query; for that, set :attr:`scale_factor=1.`.
    transpose_batch_sequence: bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. if set to True, the input tensors
        should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """
    head_dim: int
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None
    attention_dropout: float = 0.
    # Public API takes strings; converted to the internal enums inside __call__.
    attn_mask_type: str = 'causal'
    attn_bias_type: Optional[str] = None
    dtype: DType = jnp.float32
    dropout_rng_name: str = 'dropout'
    float32_logits: bool = False
    qkv_layout: str = 'bshd_bshd_bshd'
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = True

    @nn.compact
    def __call__(self,
                 query: Array,
                 key: Array,
                 value: Array,
                 mask: Optional[Array] = None,
                 bias: Optional[Array] = None,
                 *,
                 deterministic: bool = False) -> Array:
        """
        Parameters
        ----------
        query: jax.numpy.ndarray
            The details of query tensor representation is described in :attr:`qkv_layout`.
        key: jax.numpy.ndarray
            The details of key tensor representation is described in :attr:`qkv_layout`.
        value: jax.numpy.ndarray
            The details of value tensor representation is described in :attr:`qkv_layout`.
        mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out the attention softmax input.
            :attr:`True` means to mask out the corresponding values.
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift attention softmax input.
        *:
            Below parameters are keyword only
        deterministic: bool, default = False
            Disable dropout layers if set to True.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensors.
        """
        # For internal APIs, convert the user-facing string options into enums.
        if self.attn_bias_type is None:
            # Automatic choice: a provided bias tensor implies post-scale bias.
            attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
        else:
            attn_bias_type = AttnBiasType[self.attn_bias_type.upper()]
        attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
        qkv_layout = QKVLayout[self.qkv_layout.upper()]
        # NOTE(review): deleting dataclass fields from a flax Module instance is
        # unusual — presumably to force use of the canonicalized locals above;
        # confirm this is intended and supported by flax.
        del self.attn_bias_type, self.attn_mask_type, self.qkv_layout

        # The bias argument must be consistent with the resolved bias type.
        if attn_bias_type == AttnBiasType.NO_BIAS:
            assert bias is None
        else:
            assert bias is not None

        # The fused backend is opt-in via the NVTE_FUSED_ATTN environment variable.
        enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        seqlen_q = query.shape[sequence_dim]
        if qkv_layout == QKVLayout.BS3HD:
            # qkvpacked layout: key/value args are ignored, so kv length equals q length.
            seqlen_kv = seqlen_q
        else:
            seqlen_kv = key.shape[sequence_dim]

        has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype, qkv_layout,
                                                               attn_bias_type, attn_mask_type,
                                                               self.attention_dropout,
                                                               self.num_attention_heads,
                                                               self.num_gqa_groups, seqlen_q,
                                                               seqlen_kv, self.head_dim)

        use_fused_attn = (enable_fused_attn and has_fused_attn_kernel)
        if enable_fused_attn and not has_fused_attn_kernel:
            # Requested fused attention but no kernel matches this configuration:
            # warn with the full configuration, then fall back to the unfused path.
            warnings.warn("Fused attention is not enabled because there is no available kernel.\n"
                          "Fall back to the unfused attention.\n"
                          "Please try to update the cuDNN and TE to the latest version.\n"
                          f"{self.dtype=}\n{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n"
                          f"{self.attention_dropout=}\n{self.num_attention_heads=}\n"
                          f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{self.head_dim=}\n")

        dropout_rng = None
        if not deterministic and self.attention_dropout > 0.:
            # Only draw an RNG when dropout will actually be applied.
            dropout_rng = self.make_rng(self.dropout_rng_name)

        if self.scale_factor is None:
            # Default attention scaling: 1/sqrt(d_k).
            scale_factor = 1.0 / sqrt(self.head_dim)
        else:
            scale_factor = self.scale_factor
        # NOTE(review): see the note on the field deletions above — confirm intended.
        del self.scale_factor

        if not use_fused_attn:
            # unfused attention only supports splitted query, key, value
            if qkv_layout == QKVLayout.BS3HD:
                # Unpack [..., 3, h, d] into three [..., h, d] tensors.
                query, key, value = jnp.split(query, [1, 2], axis=-3)
                query, key, value = map(functools.partial(jnp.squeeze, axis=-3),
                                        [query, key, value])
            elif qkv_layout == QKVLayout.BSHD_BS2HD:
                # Unpack [..., 2, h, d] key/value into two [..., h, d] tensors.
                key, value = jnp.split(key, [1], axis=-3)
                key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value])
            else:
                assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD

            x = _UnfusedDotProductAttention(attention_dropout=self.attention_dropout,
                                            attn_mask_type=attn_mask_type,
                                            attn_bias_type=attn_bias_type,
                                            dtype=self.dtype,
                                            float32_logits=self.float32_logits,
                                            scale_factor=scale_factor,
                                            transpose_batch_sequence=self.transpose_batch_sequence)(
                                                query,
                                                key,
                                                value,
                                                mask,
                                                bias,
                                                dropout_rng=dropout_rng,
                                                deterministic=deterministic)
        else:
            x = _FusedDotProductAttention(
                attention_dropout=self.attention_dropout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                dtype=self.dtype,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                qkv_layout=qkv_layout,
            )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)
        return x
def rotary_pos_emb(x: Array, windows: Tuple[int, int], transpose_batch_sequence: bool): def rotary_pos_emb(x: Array, windows: Tuple[int, int], transpose_batch_sequence: bool):
""" """
Rotary Positional Embedding Rotary Positional Embedding
...@@ -259,43 +574,44 @@ def rotary_pos_emb(x: Array, windows: Tuple[int, int], transpose_batch_sequence: ...@@ -259,43 +574,44 @@ def rotary_pos_emb(x: Array, windows: Tuple[int, int], transpose_batch_sequence:
return jnp.concatenate([part_1, part_2], axis=-1) return jnp.concatenate([part_1, part_2], axis=-1)
# Batched dynamic_slice_in_dim: vmap maps over the start-index argument only
# (in_axes=(None, 0, None, None)), so one operand is sliced at a vector of starts.
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
r""" r"""
Multi-head Attention (MHA), including Query, Multi-head Attention (MHA), including Query,
Key, Value and Output projection. Key, Value and Output projection.
.. note::
Argument :attr:`mask` will be ignored when
:attr:`attn_mask_type` is set to `"causal"`.
Parameters Parameters
---------- ----------
head_dim : int head_dim: int
The hidden dimension of each attention head. The hidden dimension of each attention head.
num_heads : int num_attention_heads: int
The number of attention heads The number of attention heads.
num_gqa_groups : int, default = `None` num_gqa_groups: int, default = `None`
Number of GQA groups. When `None` is present, it is equal to num_heads. Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
Grouped Query Attention is described in Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_. `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the querys. This only affects the keys and values, not the querys.
GQA-1 is equivalent to Multi-Query Attention GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`. is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
dropout_rate : float, default = 0.0 attention_dropout: float, default = 0.0
Dropout probability for the dropout op during multi-head attention. Dropout probability for the dropout op after the softmax.
attn_mask_type: str, default = 'causal'
Type of the attention mask passed into softmax operation in the attention.
Available options: {'no_mask', 'padding', 'causal', 'causal_padding'}
Introduced in v0.10.0.
attn_bias_type: Optional[str], default = None
Type of the attention bias passed in the attention.
Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
When default is present, the type is automatically decided by the MHA's bias parameter.
Where it is `post_scale_bias` if there is bias. Otherwise `no_bias` is used.
dropout_rng_name: str, default = 'dropout' dropout_rng_name: str, default = 'dropout'
The key in given RNGs via flax.linen.Module.apply that is used The key in given RNGs via flax.linen.Module.apply that is used
to generate Dropout masks in the core attention. to generate Dropout masks in the core attention.
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm' layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
Indicate the type of layer normalization. Indicate the type of layer normalization.
layernorm_epsilon: float, default = 1e-6 layernorm_epsilon: float, default = 1e-6
A value added to the denominator of layer normalization for numerical stability. A value added to the denominator of layer normalization for numerical stability.
zero_centered_gamma : bool, default = False zero_centered_gamma: bool, default = False
If set to `True`, the LayerNorm formula changes to If set to `True`, the LayerNorm formula changes to
.. math:: .. math::
...@@ -305,21 +621,20 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -305,21 +621,20 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
This parameter is only applicable for 'layernorm'. This parameter is only applicable for 'layernorm'.
kernel_init: Initializer, default = kernel_init: Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal') flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
Used for initializing the QKV and Output projection weights. Used for initializing the QKV and output projection weights.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
use_bias: bool, default = False use_bias: bool, default = False
Indicate whether or not to enable bias shifting for QKVO projections. Indicate whether or not to enable bias shifting for QKV and output projections.
If set to False, the layer will not learn additive biases. If set to False, the layer will not learn additive biases.
bias_init: Initializer, default = flax.linen.initializers.zeros bias_init: Initializer, default = flax.linen.initializers.zeros
Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`. Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
apply_residual_connection_post_layernorm : bool, default = False input_layernorm: bool, default = True
Indicate if apply residual connection with the output of layer normalization. If set to False, layer normalization to the input is not applied.
output_layernorm : bool, default = False return_layernorm_output: bool, default = False
Indicate if apply a layer normalization at the end of MHA. If set to True, output of layernorm is returned from the forward together with the output
attn_mask_type: {'causal', 'padding'}, default = 'causal' of the linear transformation.
Type of attention mask passed into softmax operation. Example use case: residual connection for transformer module is taken post layernorm.
Introduced in v0.10.0.
enable_rotary_pos_emb: bool, default = False enable_rotary_pos_emb: bool, default = False
Whether to enable rotary position embedding to projected query and key. Whether to enable rotary position embedding to projected query and key.
rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000) rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
...@@ -327,58 +642,101 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -327,58 +642,101 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
only used when :attr:`enable_rotary_pos_emb=True` only used when :attr:`enable_rotary_pos_emb=True`
enable_sequence_parallel: bool, default = False enable_sequence_parallel: bool, default = False
Whether to enable sequence parallelism to operations except dot. Whether to enable sequence parallelism to operations except dot.
num_heads: Optional[int], default = None
Deprecated. Please refer to `num_attention_heads`.
dropout_rate: Optional[float], default = None
Deprecated. Please refer to `attention_dropout`.
output_layernorm: Optional[bool], default = None
Deprecated. Please refer to `input_layernorm`.
apply_residual_connection_post_layernorm: Optional[bool], default = None
Deprecated. Please refer to `return_layernorm_output`.
Optimization parameters Optimization parameters
----------------------- -----------------------
dtype :jax.numpy.dtype, default = jax.numpy.float32 dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters. The data type used to allocate the initial parameters.
fuse_qkv: bool, default = True fuse_qkv_params: bool, default = True
If set to True, this module exposes a single fused If set to True, this module exposes a single fused
parameter for query-key-value for self-attention and key-value for parameter for query-key-value for self-attention and key-value for
cross-attention. cross-attention.
transpose_batch_sequence : bool, default = True transpose_batch_sequence: bool, default = True
Indicate whether the input tensors were switched axis of batch Indicate whether the input tensors were switched axis of batch
and sequence length dimension. if set to True, the input tensors and sequence length dimension. if set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
scale_attn_logits: bool, default = False scale_attn_logits: bool, default = False
Indicate whether to scale attention logits. Indicate whether to scale attention logits.
If set to True, :math:`\frac{Q}{\sqrt{head_dim}*K}`, If set to True, :math:`\frac{Q}{\sqrt{head\_dim}*K}`,
else :math:`Q*K` else :math:`Q*K`
scaled_query_init: bool, default = `True` scaled_query_init: bool, default = True
Whether to scale WQ on initialization by :math:`\sqrt{head_dim}` Whether to scale WQ on initialization by :math:`\frac{1}{\sqrt{head\_dim}}`
float32_logits : bool, default = False float32_logits: bool, default = False
Whether to compute attention logits in float32. Whether to compute attention logits in float32 for the unfused attention backend.
For fused attention backend, the accumulation is always float32 without the perf overhead.
fuse_qkv: bool, default = None
Deprecated. Please refer `fuse_qkv_params`
""" """
head_dim: int head_dim: int
num_heads: int num_attention_heads: int
num_gqa_groups: int | None = None num_gqa_groups: Optional[int] = None
dropout_rate: float = 0. attention_dropout: float = 0.
dropout_rng_name: str = 'dropout' dropout_rng_name: str = 'dropout'
input_layernorm: bool = True
layernorm_type: str = "layernorm" layernorm_type: str = "layernorm"
layernorm_epsilon: float = 1e-6 layernorm_epsilon: float = 1e-6
return_layernorm_output: bool = False
zero_centered_gamma: bool = False zero_centered_gamma: bool = False
kernel_init: Initializer = None kernel_init: Initializer = None
use_bias: bool = False use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros bias_init: Initializer = nn.initializers.zeros
apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False
attn_mask_type: str = 'causal' attn_mask_type: str = 'causal'
attn_bias_type: Optional[str] = None
enable_rotary_pos_emb: bool = False enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000) rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
dtype: DType = jnp.float32 dtype: DType = jnp.float32
fuse_qkv: bool = True fuse_qkv_params: bool = True
transpose_batch_sequence: bool = True transpose_batch_sequence: bool = True
enable_sequence_parallel: bool = False enable_sequence_parallel: bool = False
scale_attn_logits: bool = False scale_attn_logits: bool = False
scaled_query_init: bool = True scaled_query_init: bool = True
float32_logits: bool = False # computes logits in float32 for stability. float32_logits: bool = False
# Deprecated parameters
num_heads: Optional[int] = None
dropout_rate: Optional[float] = None
output_layernorm: Optional[bool] = None
apply_residual_connection_post_layernorm: Optional[bool] = None
fuse_qkv: Optional[bool] = None
def __post_init__(self): def __post_init__(self):
# Deal with the deprecated parameters
if self.num_heads is not None:
self.num_attention_heads = self.num_heads
warnings.warn(
f"{__class__}.num_heads is deprecated. It will be removed recently. "
f"Please uses {__class__}.num_attention_heads as the new API.", DeprecationWarning)
if self.dropout_rate is not None:
self.attention_dropout = self.dropout_rate
warnings.warn(
f"{__class__}.dropout_rate is deprecated. It will be removed recently. "
f"Please use {__class__}.attention_dropout as the new API.", DeprecationWarning)
if self.apply_residual_connection_post_layernorm is not None:
warnings.warn(
f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
f"It will be removed recently, please use {__class__}.return_layernorm_output.",
DeprecationWarning)
if self.fuse_qkv is not None:
warnings.warn(
f"{__class__}.fuse_qkv is deprecated. It will be removed recently. "
f"Please use {__class__}.fuse_qkv_params as the new API.", DeprecationWarning)
assert self.output_layernorm is None, (
f"{__class__}.output_layernorm is deprecated. It will be removed recently. "
f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm.")
if self.kernel_init is None: if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal') self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
if self.num_gqa_groups is None: if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_heads self.num_gqa_groups = self.num_attention_heads
super().__post_init__() super().__post_init__()
@nn.compact @nn.compact
...@@ -396,23 +754,24 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -396,23 +754,24 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
Parameters Parameters
---------- ----------
inputs_q : jax.numpy.ndarray inputs_q: jax.numpy.ndarray
Input tensor for query projection. Input tensor for query projection.
inputs_kv : jax.numpy.ndarray inputs_kv: jax.numpy.ndarray
Input tensor for key/value projection. Input tensor for key/value projection.
mask : jax.numpy.ndarray, default = None mask: jax.numpy.ndarray, default = None
Boolean tensor used to mask out self-attention softmax input. Boolean tensor used to mask out the attention softmax input.
bias : jax.numpy.ndarray, default = None :attr:`True` means mask out the corresponding values.
A tensor used to shift self-attention softmax input. bias: jax.numpy.ndarray, default = None
A tensor used to shift the attention softmax input.
* *
decode : bool,default = False decode: bool, default = False
Indicate whether to prepare and use an autoregressive cache. Indicate whether to prepare and use an autoregressive cache.
deterministic : bool,default = False deterministic: bool, default = False
Disable dropout layers if set to True. Disable dropout layers if set to True.
Returns Returns
------- -------
outputs : jax.numpy.ndarray outputs: jax.numpy.ndarray
Output tensors. Output tensors.
""" """
...@@ -450,56 +809,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -450,56 +809,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
return jnp.stack([k_kernel, v_kernel], axis=-2, dtype=dtype) return jnp.stack([k_kernel, v_kernel], axis=-2, dtype=dtype)
# TODO(rewang): make it configurable for pre_scale_bias
attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
def canonicalize_attn_mask_type(attn_mask_type):
"""
Convert the string to AttnMaskType
"""
if attn_mask_type == 'causal':
return AttnMaskType.PADDING_CAUSAL_MASK
if attn_mask_type == 'padding':
return AttnMaskType.PADDING_MASK
raise ValueError(f"Unsupported {attn_mask_type=}, "
"supported attn_mask_type = {'causal', 'padding'}")
is_self_attn = (inputs_q is inputs_kv)
is_gqa = (self.num_heads != self.num_gqa_groups)
is_qkvpack = (is_self_attn and not is_gqa)
qkv_layout = QKVLayout.BS3HD if is_self_attn else QKVLayout.BSHD_BS2HD
attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
q_seqlen = inputs_q.shape[0] if self.transpose_batch_sequence else inputs_q.shape[1]
kv_seqlen = inputs_kv.shape[0] if self.transpose_batch_sequence else inputs_kv.shape[1]
enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))
has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype, qkv_layout,
attn_bias_type, attn_mask_type,
self.dropout_rate, self.num_heads,
self.num_gqa_groups, q_seqlen,
kv_seqlen, self.head_dim)
use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \
has_fused_attn_kernel and \
enable_fused_attn
if enable_fused_attn and not use_fused_attn:
reason = ""
if decode:
reason += f"decode=False is required but got {decode}, "
if self.transpose_batch_sequence:
reason += f"transpose_batch_sequence=False is required " \
f"but got {self.transpose_batch_sequence}, "
if not self.fuse_qkv:
reason += f"fuse_qkv=True is required but got {self.fuse_qkv}, "
if not has_fused_attn_kernel:
reason += "no fused attention kernel is available, "
warnings.warn(
f"Fused attention is not enabled. Because " \
f"{reason}fall back to unfused attention.")
def generate_batch_seqlen_logical_axes(is_sharded_seq): def generate_batch_seqlen_logical_axes(is_sharded_seq):
sequence_dim = 0 if self.transpose_batch_sequence else 1 sequence_dim = 0 if self.transpose_batch_sequence else 1
batch_dim = 1 - sequence_dim batch_dim = 1 - sequence_dim
...@@ -510,24 +819,27 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -510,24 +819,27 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
axes[sequence_dim] = SEQLEN_TP_AXES if is_sharded_seq else SEQLEN_AXES axes[sequence_dim] = SEQLEN_TP_AXES if is_sharded_seq else SEQLEN_AXES
return tuple(axes) return tuple(axes)
is_self_attn = (inputs_q is inputs_kv)
is_gqa = (self.num_attention_heads != self.num_gqa_groups)
is_qkvpack = (is_self_attn and not is_gqa)
inputs_logical_axes_maybe_sp = (*generate_batch_seqlen_logical_axes( inputs_logical_axes_maybe_sp = (*generate_batch_seqlen_logical_axes(
self.enable_sequence_parallel), HIDDEN_AXES) self.enable_sequence_parallel), HIDDEN_AXES)
inputs_logical_axes_no_sp = (*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES) inputs_logical_axes_no_sp = (*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES)
inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp) inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp)
residual = inputs_q if self.fuse_qkv_params:
if self.fuse_qkv:
if is_qkvpack: if is_qkvpack:
qkv_proj, ln_out = LayerNormDenseGeneral( qkv_proj, ln_out = LayerNormDenseGeneral(
enable_layernorm=not self.output_layernorm, enable_layernorm=self.input_layernorm,
layernorm_type=self.layernorm_type, layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
axis=-1, axis=-1,
features=(3, self.num_heads * self.head_dim), features=(3, self.num_attention_heads * self.head_dim),
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm, return_layernorm_output=self.return_layernorm_output,
scale_axes=(W_NO_SHARD_AXES,), scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
...@@ -540,19 +852,17 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -540,19 +852,17 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
name='qkv', name='qkv',
dtype=self.dtype)(inputs_q) dtype=self.dtype)(inputs_q)
qkv_proj = checkpoint_name(qkv_proj, 'combined_qkv_proj') qkv_proj = checkpoint_name(qkv_proj, 'combined_qkv_proj')
if not use_fused_attn: qkv_layout = QKVLayout.BS3HD
query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
else: else:
query, ln_out = LayerNormDenseGeneral( query, ln_out = LayerNormDenseGeneral(
enable_layernorm=not self.output_layernorm, enable_layernorm=self.input_layernorm,
layernorm_type=self.layernorm_type, layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
axis=-1, axis=-1,
features=self.num_heads * self.head_dim, features=self.num_attention_heads * self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=(self.apply_residual_connection_post_layernorm return_layernorm_output=(self.return_layernorm_output or is_self_attn),
or is_self_attn),
scale_axes=(W_NO_SHARD_AXES,), scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes=(W_FSDP_AXES, W_TP_AXES), kernel_axes=(W_FSDP_AXES, W_TP_AXES),
...@@ -580,8 +890,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -580,8 +890,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
name='kv', name='kv',
dtype=self.dtype)(inputs_kv) dtype=self.dtype)(inputs_kv)
kv_proj = checkpoint_name(kv_proj, 'combined_kv_proj') kv_proj = checkpoint_name(kv_proj, 'combined_kv_proj')
if not use_fused_attn: qkv_layout = QKVLayout.BSHD_BS2HD
key, value = jnp.split(kv_proj, [1], axis=-2)
else: else:
kv_projection = functools.partial( kv_projection = functools.partial(
DenseGeneral, DenseGeneral,
...@@ -594,12 +903,12 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -594,12 +903,12 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
bias_axes=(W_TP_AXES,), bias_axes=(W_TP_AXES,),
dtype=self.dtype) dtype=self.dtype)
query, ln_out = LayerNormDenseGeneral( query, ln_out = LayerNormDenseGeneral(
enable_layernorm=not self.output_layernorm, enable_layernorm=self.input_layernorm,
layernorm_type=self.layernorm_type, layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
axis=-1, axis=-1,
features=self.num_heads * self.head_dim, features=self.num_attention_heads * self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=True, return_layernorm_output=True,
scale_axes=(W_NO_SHARD_AXES,), scale_axes=(W_NO_SHARD_AXES,),
...@@ -620,44 +929,31 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -620,44 +929,31 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv) key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv) value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv)
query = checkpoint_name(query, 'query_proj')
if self.apply_residual_connection_post_layernorm: key = checkpoint_name(key, 'key_proj')
assert ln_out is not None value = checkpoint_name(value, 'value_proj')
residual = ln_out qkv_layout = QKVLayout.BSHD_BSHD_BSHD
if self.enable_rotary_pos_emb: if self.enable_rotary_pos_emb:
if self.fuse_qkv and use_fused_attn: if qkv_layout == QKVLayout.BS3HD:
if is_qkvpack:
query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2) query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
else: elif qkv_layout == QKVLayout.BSHD_BS2HD:
key, value = jnp.split(kv_proj, [1], axis=-2) key, value = jnp.split(kv_proj, [1], axis=-2)
else:
assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
query = rotary_pos_emb(query, self.rotary_pos_emb_windows, query = rotary_pos_emb(query, self.rotary_pos_emb_windows,
self.transpose_batch_sequence) self.transpose_batch_sequence)
key = rotary_pos_emb(key, self.rotary_pos_emb_windows, self.transpose_batch_sequence) key = rotary_pos_emb(key, self.rotary_pos_emb_windows, self.transpose_batch_sequence)
qkv_layout = QKVLayout.BSHD_BSHD_BSHD
if use_fused_attn: if qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
if is_qkvpack: query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
qkv_proj = jnp.concatenate([query, key, value], axis=-2)
else:
kv_proj = jnp.concatenate([key, value], axis=-2)
if not use_fused_attn:
query = checkpoint_name(query, 'query_proj')
key = checkpoint_name(key, 'key_proj')
value = checkpoint_name(value, 'value_proj')
query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim)) key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim)) value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
qkv_sharding_constraint = \
(SEQLEN_AXES, BATCH_AXES, HEAD_AXES, HIDDEN_AXES) \
if self.transpose_batch_sequence \
else (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES)
query = with_sharding_constraint_by_logical_axes(query, qkv_sharding_constraint)
key = with_sharding_constraint_by_logical_axes(key, qkv_sharding_constraint)
value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint)
if decode: if decode:
assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
is_initialized = self.has_variable('cache', 'cached_key') is_initialized = self.has_variable('cache', 'cached_key')
cached_key = self.variable('cache', 'cached_key', jnp.zeros, key.shape, key.dtype) cached_key = self.variable('cache', 'cached_key', jnp.zeros, key.shape, key.dtype)
...@@ -667,12 +963,12 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -667,12 +963,12 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
lambda: jnp.array(0, dtype=jnp.int32)) lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized: if is_initialized:
if self.transpose_batch_sequence: if self.transpose_batch_sequence:
length, batch, num_heads, head_dim = cached_key.value.shape length, batch, num_attention_heads, head_dim = cached_key.value.shape
expected_shape = (1, batch, num_heads, head_dim) expected_shape = (1, batch, num_attention_heads, head_dim)
one_hot_indices_shape = (length, 1, 1, 1) one_hot_indices_shape = (length, 1, 1, 1)
else: else:
batch, length, num_heads, head_dim = cached_key.value.shape batch, length, num_attention_heads, head_dim = cached_key.value.shape
expected_shape = (batch, 1, num_heads, head_dim) expected_shape = (batch, 1, num_attention_heads, head_dim)
one_hot_indices_shape = (1, length, 1, 1) one_hot_indices_shape = (1, length, 1, 1)
# Sanity shape check of cached key against input query. # Sanity shape check of cached key against input query.
...@@ -694,100 +990,58 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -694,100 +990,58 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length))) mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length)))
if bias is not None: if bias is not None:
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim,
in_axes=(None, 0, None, None))
bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0), bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0),
jnp.reshape(cur_index, (-1)), 1, -2) jnp.reshape(cur_index, (-1)), 1, -2)
scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0 scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0
dropout_rng = None LEADING_AXES = (BATCH_AXES, SEQLEN_AXES)
if not deterministic and self.dropout_rate > 0.: if self.transpose_batch_sequence:
dropout_rng = self.make_rng(self.dropout_rng_name) LEADING_AXES = (SEQLEN_AXES, BATCH_AXES)
if use_fused_attn: if qkv_layout == QKVLayout.BS3HD:
assert mask is not None and mask.ndim == 4 # (b, 1, s_q, s_kv) qkv_proj = qkv_proj.reshape(*qkv_proj.shape[:2], 3, self.num_attention_heads,
assert not self.transpose_batch_sequence self.head_dim)
qkv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
seed = None qkv_proj = with_sharding_constraint_by_logical_axes(qkv_proj, qkv_sharding_constraint)
if dropout_rng is not None: dpa_args = [qkv_proj, None, None]
seed = jax.random.split(dropout_rng, num_of_devices()) elif qkv_layout == QKVLayout.BSHD_BS2HD:
# ensure the old key never used query = query.reshape(*query.shape[:2], self.num_attention_heads, self.head_dim)
del dropout_rng kv_proj = kv_proj.reshape(*kv_proj.shape[:2], 2, self.num_gqa_groups, self.head_dim)
q_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
if is_qkvpack: kv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim))
qkv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
HIDDEN_AXES)
qkv_proj = with_sharding_constraint_by_logical_axes(qkv_proj,
qkv_sharding_constraint)
x = self_fused_attn(qkv_proj,
bias,
mask,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scale_factor,
dropout_probability=self.dropout_rate,
is_training=not deterministic)
else:
assert bias is None
query = query.reshape((*query.shape[:-1], self.num_heads, self.head_dim))
kv_proj = kv_proj.reshape((*kv_proj.shape[:-1], self.num_gqa_groups, self.head_dim))
q_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES)
kv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
HIDDEN_AXES)
query = with_sharding_constraint_by_logical_axes(query, q_sharding_constraint) query = with_sharding_constraint_by_logical_axes(query, q_sharding_constraint)
kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint) kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint)
dpa_args = [query, kv_proj, None]
x = cross_fused_attn(query,
kv_proj,
bias,
mask,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scale_factor,
dropout_probability=self.dropout_rate,
is_training=not deterministic)
else: else:
assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
qkv_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
query = with_sharding_constraint_by_logical_axes(query, qkv_sharding_constraint)
key = with_sharding_constraint_by_logical_axes(key, qkv_sharding_constraint)
value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint)
dpa_args = [query, key, value]
def convert_to_softmax_type(attn_mask_type, mask): x = DotProductAttention(head_dim=self.head_dim,
""" num_attention_heads=self.num_attention_heads,
Convert the string to SoftmaxType num_gqa_groups=self.num_gqa_groups,
""" attn_mask_type=self.attn_mask_type,
if attn_mask_type == 'causal': attn_bias_type=self.attn_bias_type,
return SoftmaxType.SCALED_UPPER_TRIANG_MASKED attention_dropout=self.attention_dropout,
if attn_mask_type == 'padding':
if mask is not None:
return SoftmaxType.SCALED_MASKED
return SoftmaxType.SCALED
raise ValueError(f"Unsupported {attn_mask_type=}, "
"supported attn_mask_type = {'causal', 'padding'}")
softmax_type = convert_to_softmax_type(self.attn_mask_type, mask)
x = core_attention(query,
key,
value,
scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence,
softmax_type=softmax_type,
mask=mask,
bias=bias,
dropout_rng=dropout_rng,
dropout_rate=self.dropout_rate,
deterministic=deterministic,
dtype=self.dtype, dtype=self.dtype,
float32_logits=self.float32_logits) dropout_rng_name=self.dropout_rng_name,
float32_logits=self.float32_logits,
x = checkpoint_name(x, 'context') qkv_layout=qkv_layout.name,
scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence)(
*dpa_args, mask, bias, deterministic=deterministic)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3])) x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
attn_context_sharding_constraint = \ attn_context_sharding_constraint = (*LEADING_AXES, HIDDEN_TP_AXES)
(SEQLEN_AXES, BATCH_AXES, HIDDEN_TP_AXES) \
if self.transpose_batch_sequence \
else (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
x = with_sharding_constraint_by_logical_axes(x, attn_context_sharding_constraint) x = with_sharding_constraint_by_logical_axes(x, attn_context_sharding_constraint)
out = DenseGeneral(features=inputs_q.shape[-1], out = DenseGeneral(features=inputs_q.shape[-1],
...@@ -801,7 +1055,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -801,7 +1055,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
dtype=self.dtype, dtype=self.dtype,
name='out')(x) name='out')(x)
out = checkpoint_name(out, 'out_proj') out = checkpoint_name(out, 'out_proj')
return out, residual
return out, ln_out
class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-methods class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-methods
...@@ -810,21 +1065,21 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-met ...@@ -810,21 +1065,21 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-met
Parameters Parameters
---------- ----------
num_buckets : int num_buckets: int
The number of buckets to bucket distances between key and query positions into. The number of buckets to bucket distances between key and query positions into.
max_distance : int max_distance: int
The maximum distance before everything is lumped into the last The maximum distance before everything is lumped into the last
distance bucket. distance bucket.
num_attention_heads : int num_attention_heads: int
Number of attention heads in the transformer layer. Number of attention heads in the transformer layer.
embedding_init : Initializer, default = flax.linen.linear.default_embed_init embedding_init: Initializer, default = flax.linen.linear.default_embed_init
Used for initializing relative embedding tables. Used for initializing relative embedding tables.
embedding_axes : Tuple[str, ...], default = ('heads', 'relpos_buckets') embedding_axes: Tuple[str, ...], default = ('heads', 'relpos_buckets')
The name of axes used to shard embedding attention bias with a corresponding mesh. The name of axes used to shard embedding attention bias with a corresponding mesh.
Optimization parameters Optimization parameters
----------------------- -----------------------
dtype : jax.numpy.dtype, default = jax.numpy.float32 dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters. The data type used to allocate the initial parameters.
""" """
num_buckets: int num_buckets: int
...@@ -841,11 +1096,11 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-met ...@@ -841,11 +1096,11 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-met
Parameters Parameters
---------- ----------
q_seqlen : int q_seqlen: int
The sequence length of query. The sequence length of query.
k_seqlen : int k_seqlen: int
The sequence length of key. The sequence length of key.
bidirectional : bool, default = True bidirectional: bool, default = True
Indicate whether to allow positive memory-query relative position Indicate whether to allow positive memory-query relative position
embeddings. embeddings.
...@@ -917,11 +1172,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -917,11 +1172,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
an attention block and a feedforward network (MLP). an attention block and a feedforward network (MLP).
This standard layer is based on the paper “Attention Is All You Need”. This standard layer is based on the paper “Attention Is All You Need”.
.. note::
Argument :attr:`attention_mask` will be ignored when
:attr:`self_attn_mask_type` is set to `"causal"`.
Parameters Parameters
---------- ----------
hidden_size: int, default = 512 hidden_size: int, default = 512
...@@ -930,7 +1180,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -930,7 +1180,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
Intermediate size to which input samples are projected. Intermediate size to which input samples are projected.
num_attention_heads: int, default = 8 num_attention_heads: int, default = 8
Number of attention heads in the transformer layer. Number of attention heads in the transformer layer.
num_gqa_groups : int, default = `None` num_gqa_groups: int, default = `None`
Number of GQA groups. When `None` is present, it is equal to num_attention_heads. Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
Grouped Query Attention is described in Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_. `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
...@@ -938,11 +1188,11 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -938,11 +1188,11 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
GQA-1 is equivalent to Multi-Query Attention GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`. is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm' layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
Indicate the type of layer normalization. Indicate the type of layer normalization.
layernorm_epsilon: float, default = 1e-6 layernorm_epsilon: float, default = 1e-6
A value added to the denominator of layer normalization for numerical stability. A value added to the denominator of layer normalization for numerical stability.
zero_centered_gamma : bool, default = False zero_centered_gamma: bool, default = False
If set to `True`, the LayerNorm formula changes to If set to `True`, the LayerNorm formula changes to
.. math:: .. math::
...@@ -989,14 +1239,21 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -989,14 +1239,21 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
after the final dropout-add. default behavior is to apply layer after the final dropout-add. default behavior is to apply layer
normalization on the input side, before the QKV transformation. normalization on the input side, before the QKV transformation.
float32_attention_logits: bool, default = False float32_attention_logits: bool, default = False
If set to True, attention logits are executed in jax.numpy.float32. Whether to compute attention logits in float32 for the unfused attention backend.
For fused attention backend, the accumulation is always float32 without the perf overhead.
layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER
If set to TransformerLayerType.DECODER, an additional cross-attention block If set to TransformerLayerType.DECODER, an additional cross-attention block
is added after self-attention.this can be used for structures like `T5` is added after self-attention.this can be used for structures like `T5`
Transformer in conjunction with the TransformerLayerType.ENCODER option. Transformer in conjunction with the TransformerLayerType.ENCODER option.
self_attn_mask_type: {'causal', 'padding'}, default = 'causal' self_attn_mask_type: str, default = 'causal'
Type of attention mask passed into softmax operation. Type of the attention mask passed into softmax operation in the self attention.
Available options: {'no_mask', 'padding', 'causal', 'causal_padding'}
Introduced in v0.10.0. Introduced in v0.10.0.
self_attn_bias_type: Optional[str], default = None
Type of the attention bias passed into the self attention.
Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
When default is present, the type is automatically decided by the MHA's bias parameter.
Where it is `post_scale_bias` if there is bias. Otherwise `no_bias` is used.
enable_relative_embedding: bool, default = True enable_relative_embedding: bool, default = True
Whether to enable relative embedding as shifting of attention logits. Whether to enable relative embedding as shifting of attention logits.
relative_embedding: flax.linen.Module, default = None relative_embedding: flax.linen.Module, default = None
...@@ -1017,7 +1274,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1017,7 +1274,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
Optimization parameters Optimization parameters
----------------------- -----------------------
dtype :jax.numpy.dtype, default = jax.numpy.float32 dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters. The data type used to allocate the initial parameters.
drop_path: float, default = 0.0 drop_path: float, default = 0.0
When > 0.0, applies stochastic depth per sample in the main When > 0.0, applies stochastic depth per sample in the main
...@@ -1026,7 +1283,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1026,7 +1283,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
If set to True, `TransformerLayer` module exposes a single fused If set to True, `TransformerLayer` module exposes a single fused
parameter for query-key-value for self-attention and key-value for parameter for query-key-value for self-attention and key-value for
cross-attention. cross-attention.
transpose_batch_sequence : bool, default = False transpose_batch_sequence: bool, default = False
Indicate whether the input tensors were switched axis of batch Indicate whether the input tensors were switched axis of batch
and sequence length dimension. if set to True, the input tensors and sequence length dimension. if set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
...@@ -1041,7 +1298,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1041,7 +1298,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
hidden_size: int = 512 hidden_size: int = 512
mlp_hidden_size: int = 2048 mlp_hidden_size: int = 2048
num_attention_heads: int = 8 num_attention_heads: int = 8
num_gqa_groups: int | None = None num_gqa_groups: Optional[int] = None
layernorm_type: str = 'layernorm' layernorm_type: str = 'layernorm'
layernorm_epsilon: float = 1e-6 layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False zero_centered_gamma: bool = False
...@@ -1061,6 +1318,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1061,6 +1318,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
float32_attention_logits: bool = False float32_attention_logits: bool = False
layer_type: TransformerLayerType = TransformerLayerType.ENCODER layer_type: TransformerLayerType = TransformerLayerType.ENCODER
self_attn_mask_type: str = 'causal' self_attn_mask_type: str = 'causal'
self_attn_bias_type: Optional[str] = None
enable_relative_embedding: bool = True enable_relative_embedding: bool = True
relative_embedding: nn.Module = None relative_embedding: nn.Module = None
enable_rotary_pos_emb: bool = False enable_rotary_pos_emb: bool = False
...@@ -1097,29 +1355,29 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1097,29 +1355,29 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
Parameters Parameters
---------- ----------
inputs : jax.numpy.ndarray inputs: jax.numpy.ndarray
Input tensor. Input tensor.
encoded : jax.numpy.ndarray, default = None encoded: jax.numpy.ndarray, default = None
Output tensors of the encoder block to be fed into the decoder block if using Output tensors of the encoder block to be fed into the decoder block if using
:attr:`layer_type=TransformerLayerType.DECODER`. :attr:`layer_type=TransformerLayerType.DECODER`.
attention_mask : jax.numpy.ndarray, default = None attention_mask : jax.numpy.ndarray, default = None
Boolean tensor used to mask out self-attention softmax input. Boolean tensor used to mask out self-attention softmax input.
encoder_decoder_mask : jax.numpy.ndarray, default = None encoder_decoder_mask: jax.numpy.ndarray, default = None
Boolean tensor used to mask out cross-attention softmax input when Boolean tensor used to mask out cross-attention softmax input when
:attr:`layer_type=TransformerLayerType.DECODER`. :attr:`layer_type=TransformerLayerType.DECODER`.
deterministic: bool, default = False deterministic: bool, default = False
Disable dropout layers if set to True. Disable dropout layers if set to True.
decode: bool,default = False decode: bool, default = False
Indicate whether to prepare and use an autoregressive cache Indicate whether to prepare and use an autoregressive cache
in Multi-head attention (MHA). in Multi-head attention (MHA).
max_decode_length : bool, default = None max_decode_length: bool, default = None
The maximum length to generate relative embedding biases when The maximum length to generate relative embedding biases when
:attr:`layer_type=TransformerLayerType.DECODER` and :attr:`layer_type=TransformerLayerType.DECODER` and
:attr:`enable_relative_embedding=True`. :attr:`enable_relative_embedding=True`.
Returns Returns
------- -------
outputs : jax.numpy.ndarray outputs: jax.numpy.ndarray
Output tensors. Output tensors.
""" """
assert self.layer_type in TransformerLayerType, \ assert self.layer_type in TransformerLayerType, \
...@@ -1184,14 +1442,15 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1184,14 +1442,15 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
inputs, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)) inputs, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
# [batch, length, emb_dim] -> [batch, length, emb_dim] # [batch, length, emb_dim] -> [batch, length, emb_dim]
x, residual = MultiHeadAttention( residual = inputs
num_heads=self.num_attention_heads, x, ln_out = MultiHeadAttention(
num_attention_heads=self.num_attention_heads,
dtype=self.dtype, dtype=self.dtype,
head_dim=head_dim, head_dim=head_dim,
num_gqa_groups=self.num_gqa_groups, num_gqa_groups=self.num_gqa_groups,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel, enable_sequence_parallel=self.enable_sequence_parallel,
dropout_rate=self.attention_dropout, attention_dropout=self.attention_dropout,
dropout_rng_name=self.dropout_rng_name, dropout_rng_name=self.dropout_rng_name,
float32_logits=self.float32_attention_logits, float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits, scale_attn_logits=self.scale_attn_logits,
...@@ -1199,12 +1458,13 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1199,12 +1458,13 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
layernorm_type=self.layernorm_type, layernorm_type=self.layernorm_type,
layernorm_epsilon=self.layernorm_epsilon, layernorm_epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm, return_layernorm_output=self.apply_residual_connection_post_layernorm,
output_layernorm=self.output_layernorm, input_layernorm=not self.output_layernorm,
attn_mask_type=self.self_attn_mask_type, attn_mask_type=self.self_attn_mask_type,
attn_bias_type=self.self_attn_bias_type,
enable_rotary_pos_emb=self.enable_rotary_pos_emb, enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows, rotary_pos_emb_windows=self.rotary_pos_emb_windows,
fuse_qkv=self.fuse_qkv_params, fuse_qkv_params=self.fuse_qkv_params,
kernel_init=self.mha_kernel_init, kernel_init=self.mha_kernel_init,
use_bias=self.use_bias, use_bias=self.use_bias,
bias_init=self.bias_init, bias_init=self.bias_init,
...@@ -1236,6 +1496,11 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1236,6 +1496,11 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
x = nn.Dropout(rate=self.drop_path, x = nn.Dropout(rate=self.drop_path,
broadcast_dims=drop_path_shape, broadcast_dims=drop_path_shape,
rng_collection=self.dropout_rng_name)(x, deterministic=deterministic) rng_collection=self.dropout_rng_name)(x, deterministic=deterministic)
if self.apply_residual_connection_post_layernorm:
assert ln_out is not None
residual = ln_out
x = x + residual x = x + residual
mlp_input = x mlp_input = x
...@@ -1246,28 +1511,29 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1246,28 +1511,29 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
x = with_sharding_constraint_by_logical_axes( x = with_sharding_constraint_by_logical_axes(
x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)) x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
y, residual = MultiHeadAttention( residual = x
num_heads=self.num_attention_heads, y, ln_out = MultiHeadAttention(
num_attention_heads=self.num_attention_heads,
dtype=self.dtype, dtype=self.dtype,
head_dim=head_dim, head_dim=head_dim,
num_gqa_groups=self.num_gqa_groups, num_gqa_groups=self.num_gqa_groups,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel, enable_sequence_parallel=self.enable_sequence_parallel,
dropout_rate=self.attention_dropout, attention_dropout=self.attention_dropout,
dropout_rng_name=self.dropout_rng_name, dropout_rng_name=self.dropout_rng_name,
layernorm_type=self.layernorm_type, layernorm_type=self.layernorm_type,
layernorm_epsilon=self.layernorm_epsilon, layernorm_epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
apply_residual_connection_post_layernorm=self. return_layernorm_output=self.apply_residual_connection_post_layernorm,
apply_residual_connection_post_layernorm, input_layernorm=True, # Must do LayerNorm before MHA.
output_layernorm=False, # Must do LayerNorm before MHA.
attn_mask_type='padding', attn_mask_type='padding',
attn_bias_type='no_bias',
enable_rotary_pos_emb=self.enable_rotary_pos_emb, enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows, rotary_pos_emb_windows=self.rotary_pos_emb_windows,
float32_logits=self.float32_attention_logits, float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits, scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init, scaled_query_init=self.scaled_query_init,
fuse_qkv=self.fuse_qkv_params, fuse_qkv_params=self.fuse_qkv_params,
kernel_init=self.mha_kernel_init, kernel_init=self.mha_kernel_init,
use_bias=self.use_bias, use_bias=self.use_bias,
bias_init=self.bias_init, bias_init=self.bias_init,
...@@ -1282,6 +1548,11 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1282,6 +1548,11 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)) residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
y = hidden_dropout(y, deterministic) y = hidden_dropout(y, deterministic)
if self.apply_residual_connection_post_layernorm:
assert ln_out is not None
residual = ln_out
mlp_input = y + residual mlp_input = y + residual
mlp_input = with_sharding_constraint_by_logical_axes( mlp_input = with_sharding_constraint_by_logical_axes(
...@@ -1342,6 +1613,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1342,6 +1613,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
bias_axes=(W_NO_SHARD_AXES,), bias_axes=(W_NO_SHARD_AXES,),
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
dtype=self.dtype, dtype=self.dtype,
name="output_layer_norm")(z) name="output_layernorm")(z)
return z return z
...@@ -16,6 +16,7 @@ from transformer_engine_jax import NVTE_QKV_Layout ...@@ -16,6 +16,7 @@ from transformer_engine_jax import NVTE_QKV_Layout
from .cpp_extensions import FusedAttnHelper from .cpp_extensions import FusedAttnHelper
from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd
from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd
from .cpp_extensions import fused_attn_fwd, fused_attn_bwd
class AttnBiasType(Enum): class AttnBiasType(Enum):
...@@ -37,6 +38,21 @@ class QKVLayout(Enum): ...@@ -37,6 +38,21 @@ class QKVLayout(Enum):
"""QKV layout""" """QKV layout"""
BS3HD = NVTE_QKV_Layout.NVTE_BS3HD BS3HD = NVTE_QKV_Layout.NVTE_BS3HD
BSHD_BS2HD = NVTE_QKV_Layout.NVTE_BSHD_BS2HD BSHD_BS2HD = NVTE_QKV_Layout.NVTE_BSHD_BS2HD
BSHD_BSHD_BSHD = NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD
def canonicalize_attn_mask_type(attn_mask_type: str):
"""Convert string attn_mask_type to AttnMaskType
TE-JAX currently fall back to the padding version kernels for the libraries integration.
The overhead between padding and non-padding version should be small.
However, we will lease this limitation in the near feature.
"""
if attn_mask_type in ['causal', 'padding_causal']:
return AttnMaskType.PADDING_CAUSAL_MASK
if attn_mask_type in ['no_mask', 'padding']:
return AttnMaskType.PADDING_MASK
raise ValueError(f"Unsupported {attn_mask_type=}, "
"supported attn_mask_type={'no_mask', 'padding', 'causal', 'padding_causal'}")
def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, attn_mask_type, def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, attn_mask_type,
...@@ -83,6 +99,10 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.nda ...@@ -83,6 +99,10 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.nda
seed: jnp.ndarray, attn_bias_type: AttnBiasType, seed: jnp.ndarray, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float, attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool): dropout_probability: float, is_training: bool):
if mask is None:
batch, seqlen, *_ = qkv.shape
actual_seqlen = jnp.full((batch,), seqlen, dtype=jnp.int32)
else:
mask = jnp.logical_not(mask) mask = jnp.logical_not(mask)
actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,) actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
output, softmax_aux, rng_state = self_fused_attn_fwd(qkv, output, softmax_aux, rng_state = self_fused_attn_fwd(qkv,
...@@ -159,13 +179,18 @@ def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: ...@@ -159,13 +179,18 @@ def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask:
def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type, def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training): scaling_factor, dropout_probability, is_training):
if mask is None:
batch, s_q, *_ = q.shape
s_kv = kv.shape[1]
q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
else:
mask = jnp.logical_not(mask) mask = jnp.logical_not(mask)
q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,) q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]: if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,) kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
else: else:
# When mask is padding + causal, the actual seqlen is not the last row, use max to find it # When mask is causal, the actual seqlen is not the last row, use max to find it
kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2)) kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))
output, softmax_aux, rng_state = cross_fused_attn_fwd(q, output, softmax_aux, rng_state = cross_fused_attn_fwd(q,
...@@ -179,7 +204,9 @@ def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mas ...@@ -179,7 +204,9 @@ def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mas
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training) is_training=is_training)
output = checkpoint_name(output, 'context')
softmax_aux = checkpoint_name(softmax_aux, 'context')
rng_state = checkpoint_name(rng_state, 'context')
return output, (q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen) return output, (q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen)
...@@ -209,3 +236,100 @@ def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, d ...@@ -209,3 +236,100 @@ def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, d
_cross_fused_attn.defvjp(_cross_fused_attn_fwd_rule, _cross_fused_attn_bwd_rule) _cross_fused_attn.defvjp(_cross_fused_attn_fwd_rule, _cross_fused_attn_bwd_rule)
def fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
scaling_factor: float, dropout_probability: float, is_training: bool):
"""
Dot product attention with the seperated query, key, value
"""
output = _fused_attn(q,
k,
v,
bias,
mask,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output
@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10))
def _fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray,
mask: jnp.ndarray, seed: jnp.ndarray, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float, dropout_probability: float,
is_training: bool):
output, _ = _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training)
return output
def _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
if mask is None:
batch, s_q, *_ = q.shape
s_kv = k.shape[1]
q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
else:
mask = jnp.logical_not(mask)
q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
else:
# When mask is causal, the actual seqlen is not the last row, use max to find it
kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))
output, softmax_aux, rng_state = fused_attn_fwd(q,
k,
v,
bias,
q_actual_seqlen,
kv_actual_seqlen,
seed,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
output = checkpoint_name(output, 'context')
softmax_aux = checkpoint_name(softmax_aux, 'context')
rng_state = checkpoint_name(rng_state, 'context')
return output, (q, k, v, bias, softmax_aux, rng_state, output, q_actual_seqlen,
kv_actual_seqlen)
def _fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, dz):
q, k, v, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx
grad_q, grad_k, grad_v, grad_bias = fused_attn_bwd(q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
dz,
q_actual_seqlen,
kv_actual_seqlen,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
if attn_bias_type == AttnBiasType.NO_BIAS:
grad_bias = None
return grad_q, grad_k, grad_v, grad_bias, None, None
_fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule)
...@@ -4,5 +4,6 @@ ...@@ -4,5 +4,6 @@
"""Praxis related Modules""" """Praxis related Modules"""
from .module import FusedSoftmax, LayerNorm from .module import FusedSoftmax, LayerNorm
from .module import LayerNormLinear, LayerNormMLP, Linear, TransformerEngineBaseLayer from .module import LayerNormLinear, LayerNormMLP, Linear, TransformerEngineBaseLayer
from .transformer import MultiHeadAttention, RelativePositionBiases, TransformerLayer from .transformer import DotProductAttention, MultiHeadAttention
from .transformer import RelativePositionBiases, TransformerLayer
from ..flax.transformer import TransformerLayerType from ..flax.transformer import TransformerLayerType
...@@ -6,6 +6,7 @@ Praxis Modules related Transformer ...@@ -6,6 +6,7 @@ Praxis Modules related Transformer
""" """
from functools import partial from functools import partial
from typing import Optional, Sequence, Tuple from typing import Optional, Sequence, Tuple
import warnings
from praxis import pax_fiddle from praxis import pax_fiddle
from praxis.base_layer import WeightInit from praxis.base_layer import WeightInit
...@@ -13,9 +14,11 @@ from praxis.pytypes import JTensor ...@@ -13,9 +14,11 @@ from praxis.pytypes import JTensor
from .module import TransformerEngineBaseLayer from .module import TransformerEngineBaseLayer
from ..flax.transformer import TransformerLayerType from ..flax.transformer import TransformerLayerType
from ..flax.transformer import DotProductAttention as flax_DotProductAttention
from ..flax.transformer import MultiHeadAttention as flax_MultiHeadAttention from ..flax.transformer import MultiHeadAttention as flax_MultiHeadAttention
from ..flax.transformer import RelativePositionBiases as flax_RelativePositionBiases from ..flax.transformer import RelativePositionBiases as flax_RelativePositionBiases
from ..flax.transformer import TransformerLayer as flax_TransformerLayer from ..flax.transformer import TransformerLayer as flax_TransformerLayer
from ..fused_attn import AttnBiasType, AttnMaskType
class RelativePositionBiases(TransformerEngineBaseLayer): class RelativePositionBiases(TransformerEngineBaseLayer):
...@@ -59,30 +62,117 @@ class RelativePositionBiases(TransformerEngineBaseLayer): ...@@ -59,30 +62,117 @@ class RelativePositionBiases(TransformerEngineBaseLayer):
return self.relative_position_bias(q_seqlen, k_seqlen, bidirectional) return self.relative_position_bias(q_seqlen, k_seqlen, bidirectional)
class DotProductAttention(TransformerEngineBaseLayer):
"""DotProductAttention"""
head_dim: int = 0
num_attention_heads: int = 0
num_gqa_groups: Optional[int] = None
attention_dropout: float = 0.
attn_mask_type: AttnMaskType = 'causal'
attn_bias_type: AttnBiasType = None
dropout_rng_name: str = 'dropout'
float32_logits: bool = False
qkv_layout: str = 'bshd_bshd_bshd'
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = True
def setup(self) -> None:
"""setup"""
super().setup()
assert self.head_dim > 0, f'{self.head_dim=}'
assert self.num_attention_heads > 0, f'{self.num_attention_heads=}'
dpa_cls = partial(flax_DotProductAttention,
head_dim=self.head_dim,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
attn_mask_type=self.attn_mask_type,
attn_bias_type=self.attn_bias_type,
attention_dropout=self.attention_dropout,
dtype=self.dtype,
dropout_rng_name=self.dropout_rng_name,
float32_logits=self.float32_logits,
qkv_layout=self.qkv_layout,
scale_factor=self.scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence)
self.create_layer("dot_product_attention", dpa_cls)
def __call__(self,
query: JTensor,
key: JTensor,
value: JTensor,
mask: Optional[JTensor] = None,
bias: Optional[JTensor] = None,
*,
deterministic: bool = False) -> JTensor:
"""__call__"""
return self.dot_product_attention(query,
key,
value,
mask,
bias,
deterministic=deterministic)
class MultiHeadAttention(TransformerEngineBaseLayer): class MultiHeadAttention(TransformerEngineBaseLayer):
"""MultiHeadAttention""" """MultiHeadAttention"""
head_dim: int = 64 head_dim: int = 0
num_heads: int = 16 num_attention_heads: int = 0
num_gqa_groups: int | None = None num_gqa_groups: Optional[int] = None
dropout_rate: float = 0. attention_dropout: float = 0.
dropout_rng_name: str = 'dropout' dropout_rng_name: str = 'dropout'
input_layernorm: bool = True
layernorm_type: str = "layernorm" layernorm_type: str = "layernorm"
layernorm_epsilon: float = 1e-6 layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False zero_centered_gamma: bool = False
return_layernorm_output: bool = False
use_bias: bool = False use_bias: bool = False
bias_init: WeightInit = WeightInit.Constant(0.0) bias_init: WeightInit = WeightInit.Constant(0.0)
apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False
attn_mask_type: str = 'causal' attn_mask_type: str = 'causal'
fuse_qkv: bool = True attn_bias_type: Optional[str] = None
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = True transpose_batch_sequence: bool = True
enable_sequence_parallel: bool = False enable_sequence_parallel: bool = False
scale_attn_logits: bool = False scale_attn_logits: bool = False
scaled_query_init: bool = True scaled_query_init: bool = True
float32_logits: bool = False float32_logits: bool = False
# Deprecated parameters
num_heads: Optional[int] = None
dropout_rate: Optional[float] = None
output_layernorm: Optional[bool] = None
apply_residual_connection_post_layernorm: Optional[bool] = None
fuse_qkv: Optional[bool] = None
def __post_init__(self): def __post_init__(self):
# Deal with the deprecated parameters
if self.num_heads is not None:
self.num_attention_heads = self.num_heads
warnings.warn(
f"{__class__}.num_heads is deprecated. It will be removed recently. "
f"Please uses {__class__}.num_attention_heads as the new API.", DeprecationWarning)
if self.dropout_rate is not None:
self.attention_dropout = self.dropout_rate
warnings.warn(
f"{__class__}.dropout_rate is deprecated. It will be removed recently. "
f"Please use {__class__}.attention_dropout as the new API.", DeprecationWarning)
if self.apply_residual_connection_post_layernorm is not None:
warnings.warn(
f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
f"It will be removed recently, please use {__class__}.return_layernorm_output.",
DeprecationWarning)
if self.fuse_qkv is not None:
warnings.warn(
f"{__class__}.fuse_qkv is deprecated. It will be removed recently. "
f"Please use {__class__}.fuse_qkv_params as the new API.", DeprecationWarning)
assert self.output_layernorm is None, (
f"{__class__}.output_layernorm is deprecated. It will be removed recently. "
f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm.")
if self.num_gqa_groups is None: if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_heads self.num_gqa_groups = self.num_heads
super().__post_init__() super().__post_init__()
...@@ -91,24 +181,28 @@ class MultiHeadAttention(TransformerEngineBaseLayer): ...@@ -91,24 +181,28 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
"""setup""" """setup"""
super().setup() super().setup()
assert self.head_dim > 0, f'{self.head_dim=}'
assert self.num_attention_heads > 0, f'{self.num_attention_heads=}'
mha_cls = partial( mha_cls = partial(
flax_MultiHeadAttention, flax_MultiHeadAttention,
dtype=self.dtype, dtype=self.dtype,
head_dim=self.head_dim, head_dim=self.head_dim,
num_heads=self.num_heads, num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups, num_gqa_groups=self.num_gqa_groups,
dropout_rate=self.dropout_rate, attention_dropout=self.attention_dropout,
dropout_rng_name=self.dropout_rng_name, dropout_rng_name=self.dropout_rng_name,
input_layernorm=self.input_layernorm,
layernorm_type=self.layernorm_type, layernorm_type=self.layernorm_type,
layernorm_epsilon=self.layernorm_epsilon, layernorm_epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
return_layernorm_output=self.return_layernorm_output,
kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init), kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
use_bias=self.use_bias, use_bias=self.use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init), bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
output_layernorm=self.output_layernorm,
attn_mask_type=self.attn_mask_type, attn_mask_type=self.attn_mask_type,
fuse_qkv=self.fuse_qkv, attn_bias_type=self.attn_bias_type,
fuse_qkv_params=self.fuse_qkv_params,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel, enable_sequence_parallel=self.enable_sequence_parallel,
scale_attn_logits=self.scale_attn_logits, scale_attn_logits=self.scale_attn_logits,
...@@ -140,7 +234,7 @@ class TransformerLayer(TransformerEngineBaseLayer): ...@@ -140,7 +234,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
hidden_size: int = 512 hidden_size: int = 512
mlp_hidden_size: int = 2048 mlp_hidden_size: int = 2048
num_attention_heads: int = 8 num_attention_heads: int = 8
num_gqa_groups: int | None = None num_gqa_groups: Optional[int] = None
layernorm_type: str = 'layernorm' layernorm_type: str = 'layernorm'
layernorm_epsilon: float = 1e-6 layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False zero_centered_gamma: bool = False
...@@ -158,6 +252,7 @@ class TransformerLayer(TransformerEngineBaseLayer): ...@@ -158,6 +252,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
float32_attention_logits: bool = False float32_attention_logits: bool = False
layer_type: TransformerLayerType = TransformerLayerType.ENCODER layer_type: TransformerLayerType = TransformerLayerType.ENCODER
self_attn_mask_type: str = 'causal' self_attn_mask_type: str = 'causal'
self_attn_bias_type: Optional[str] = None
enable_rotary_pos_emb: bool = False enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000) rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
enable_relative_embedding: bool = True enable_relative_embedding: bool = True
...@@ -226,6 +321,7 @@ class TransformerLayer(TransformerEngineBaseLayer): ...@@ -226,6 +321,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
float32_attention_logits=self.float32_attention_logits, float32_attention_logits=self.float32_attention_logits,
layer_type=self.layer_type, layer_type=self.layer_type,
self_attn_mask_type=self.self_attn_mask_type, self_attn_mask_type=self.self_attn_mask_type,
self_attn_bias_type=self.self_attn_bias_type,
enable_rotary_pos_emb=self.enable_rotary_pos_emb, enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows, rotary_pos_emb_windows=self.rotary_pos_emb_windows,
enable_relative_embedding=self.enable_relative_embedding, enable_relative_embedding=self.enable_relative_embedding,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment