Unverified commit 9b2fed51 authored by Reese Wang, committed by GitHub

[JAX] Refine MHA API and add DPA API (#653)



* Refine MHA API
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Reuse func from the flax
Signed-off-by: Reese Wang <rewang@nvidia.com>

* DPA draft
Signed-off-by: Reese Wang <rewang@nvidia.com>

* qkv packed draft
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix test_layer with fused attn
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add attn_bias_type and enhance a few code flows
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Move scale_factor from __call__ to init
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Enhance the docs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add DPA public API and tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refine docs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refine docs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix conflict
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add qkv separate fused attn
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Apply BSHD_BSHD_BSHD format
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove debug log
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add fused attention layer tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add NVTE_FUSED_ATTN docs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fine-grained fused attn settings
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove the default values of num_attention_heads and head_dim
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add teardown for fused attn env
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Unify the Optional notation
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix Pre/Post scale bias comments
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add no_mask tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add checkpoint_name for fused attn
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix the fused attn batcher
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent fb2f952a
......@@ -45,6 +45,9 @@ Modules
.. autoapiclass:: transformer_engine.jax.flax.RelativePositionBiases(num_buckets, max_distance, num_heads, **kwargs)
:members: __call__
.. autoapiclass:: transformer_engine.jax.flax.DotProductAttention(head_dim, num_heads, **kwargs)
:members: __call__
.. autoapiclass:: transformer_engine.jax.flax.MultiHeadAttention(head_dim, num_heads, **kwargs)
:members: __call__
......
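For orientation, a minimal usage sketch of the new flax DotProductAttention module documented above. The constructor arguments follow the autoapiclass signature and the tests later in this change; the tensor shapes, dtype, and the deterministic flag are illustrative assumptions, not taken from this diff:

    import jax
    import jax.numpy as jnp
    from transformer_engine.jax.flax import DotProductAttention

    # Assumed layout: (batch, seqlen, heads, head_dim) with transpose_batch_sequence=False
    q = k = v = jnp.ones((2, 128, 16, 64), dtype=jnp.bfloat16)

    dpa = DotProductAttention(head_dim=64,
                              num_attention_heads=16,
                              attn_mask_type='causal',
                              transpose_batch_sequence=False)
    variables = dpa.init(jax.random.PRNGKey(0), q, k, v, deterministic=True)
    out = dpa.apply(variables, q, k, v, deterministic=True)  # same shape as q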
......@@ -20,7 +20,7 @@ from jax import value_and_grad, jit
from jax.typing import ArrayLike, DTypeLike
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn, fused_attn
from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
......@@ -144,6 +144,9 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng
kv = jnp.concatenate((key, value), axis=-3)
return cross_fused_attn(query, kv, bias, mask, dropout_rng,
**kwargs).astype(query.dtype)
case QKVLayout.BSHD_BSHD_BSHD:
return fused_attn(query, key, value, bias, mask, dropout_rng,
**kwargs).astype(query.dtype)
@dataclass
......@@ -337,6 +340,7 @@ class FusedAttnRunner:
@pytest.mark.parametrize('qkv_layout', [
pytest.param(QKVLayout.BS3HD, id='qkvpacked'),
pytest.param(QKVLayout.BSHD_BS2HD, id='kvpacked'),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id='separate'),
])
@pytest.mark.parametrize('dropout_prob', [0., 0.1])
@pytest.mark.parametrize('is_training',
......
......@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
import os
from functools import partial
import flax
......@@ -20,6 +21,16 @@ from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
is_fp8_supported, reason = is_fp8_available()
@pytest.fixture(autouse=True, scope='module')
def enable_fused_attn():
"""
Enable fused attention
"""
os.environ["NVTE_FUSED_ATTN"] = "1"
yield
del os.environ["NVTE_FUSED_ATTN"]
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
......@@ -93,6 +104,7 @@ _KEY_OF_ENABLE_ROPE = "enable_rotary_pos_emb"
BASE_ATTRS = {
_KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_NUM_HEADS: 8,
_KEY_OF_DROPOUT_RATE: 0,
}
ATTRS = [{
......@@ -221,6 +233,7 @@ class TestEncoderLayer:
ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng)
test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng)
if attrs[_KEY_OF_DROPOUT_RATE] == 0.: # Skip elementwise checking for dropout
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
del data_rng, init_rng, apply_rng
......@@ -282,9 +295,6 @@ class TestEncoderLayer:
test_out, test_grads = grad_fn(inputs, test_masks, test_params, test_others, test_layer,
apply_rng)
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad
def reorganize_test_wgrad(test_wgrad, attrs):
num_heads = attrs.get(_KEY_OF_NUM_HEADS)
num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads)
......@@ -328,6 +338,10 @@ class TestEncoderLayer:
del unfreeze_test_wgrad['mlp']['wo_kernel']
return unfreeze_test_wgrad
if attrs[_KEY_OF_DROPOUT_RATE] == 0.: # Skip elementwise checking for dropout
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad
compare_dict(ref_grads[1],
reorganize_test_wgrad(test_grads[1], attrs),
rtol=rtol,
......@@ -430,6 +444,7 @@ class TestDecoderLayer:
ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng)
test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng)
if attrs[_KEY_OF_DROPOUT_RATE] == 0.: # Skip elementwise checking for dropout
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
del data_rng, init_rng, apply_rng
......@@ -492,9 +507,6 @@ class TestDecoderLayer:
test_out, test_grads = grad_fn(inputs, test_masks, test_params, test_others, test_layer,
apply_rng)
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad
def reorganize_test_wgrad(test_wgrad, attrs):
num_heads = attrs.get(_KEY_OF_NUM_HEADS)
num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads)
......@@ -547,6 +559,9 @@ class TestDecoderLayer:
del unfreeze_test_wgrad['mlp']['wo_kernel']
return unfreeze_test_wgrad
if attrs[_KEY_OF_DROPOUT_RATE] == 0.: # Skip elementwise checking for dropout
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad
compare_dict(ref_grads[1],
reorganize_test_wgrad(test_grads[1], attrs),
rtol=rtol,
......
......@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
import os
from functools import partial
from typing import Dict
......@@ -14,12 +15,14 @@ import pytest
from utils import assert_allclose
from transformer_engine_jax import get_device_compute_capability
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast, update_fp8_metas, update_collections
from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral
from transformer_engine.jax.flax import LayerNorm as flax_LayerNorm
from transformer_engine.jax.flax import LayerNormMLP as flax_LayerNormMLP
from transformer_engine.jax.flax import MultiHeadAttention as flax_MultiHeadAttention
from transformer_engine.jax.flax import DotProductAttention as flax_DotProductAttention
from transformer_engine.jax.flax import RelativePositionBiases as flax_RelativePositionBiases
from transformer_engine.jax.flax import TransformerLayer as flax_TransformerLayer
from transformer_engine.jax.flax.module import Softmax
......@@ -27,8 +30,8 @@ from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
from transformer_engine.jax.praxis import LayerNorm
from transformer_engine.jax.praxis import FusedSoftmax
from transformer_engine.jax.praxis import LayerNormLinear, LayerNormMLP, Linear
from transformer_engine.jax.praxis import MultiHeadAttention, RelativePositionBiases
from transformer_engine.jax.praxis import TransformerEngineBaseLayer
from transformer_engine.jax.praxis import DotProductAttention, MultiHeadAttention
from transformer_engine.jax.praxis import RelativePositionBiases, TransformerEngineBaseLayer
from transformer_engine.jax.praxis import TransformerLayer, TransformerLayerType
from transformer_engine.jax.softmax import SoftmaxType
......@@ -40,6 +43,19 @@ ENABLE_FP8 = [False, True]
FP8_FORMATS = [Format.E4M3, Format.HYBRID]
@pytest.fixture(autouse=True, scope='module')
def enable_fused_attn():
"""
Enable fused attention for Hopper+ architectures.
Fused attention kernels on pre-Hopper architectures are not deterministic.
"""
if get_device_compute_capability(0) >= 90:
os.environ["NVTE_FUSED_ATTN"] = "1"
yield
if "NVTE_FUSED_ATTN" in os.environ:
del os.environ["NVTE_FUSED_ATTN"]
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
......@@ -101,6 +117,7 @@ class TestLayer:
lyr_name = self.get_layer_name()
if 'params' in flax_variables:
synced_praxis_variables['params'][lyr_name]['cld'] = \
flax.core.unfreeze(flax_variables['params'])
......@@ -111,6 +128,7 @@ class TestLayer:
lyr_name = self.get_layer_name()
if 'params' in synced_praxis_grads:
synced_praxis_grads['params'] = \
synced_praxis_grads['params'][lyr_name]['cld']
......@@ -671,6 +689,86 @@ class TestRelativePositionBias(TestLayer):
assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol)
class DotProductAttnAttr:
ATTN_MASK_TYPE = 'attn_mask_type'
NUM_GQA_GROUPS = 'num_gqa_groups'
TRANSPOSE_BS = 'transpose_batch_sequence'
SCALE_FACTOR = 'scale_factor'
ATTRS = [{
ATTN_MASK_TYPE: 'padding',
TRANSPOSE_BS: True,
SCALE_FACTOR: 0.125,
}, {
ATTN_MASK_TYPE: 'padding_causal',
TRANSPOSE_BS: True,
SCALE_FACTOR: 0.125,
}, {
ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: True,
SCALE_FACTOR: 0.125,
}, {
ATTN_MASK_TYPE: 'padding',
TRANSPOSE_BS: False,
SCALE_FACTOR: 0.125,
}, {
ATTN_MASK_TYPE: 'padding_causal',
TRANSPOSE_BS: False,
SCALE_FACTOR: 2.,
}, {
ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: False,
SCALE_FACTOR: 1.,
}, {
ATTN_MASK_TYPE: 'no_mask',
TRANSPOSE_BS: False,
SCALE_FACTOR: 1.,
}]
class TestDotProductAttn(TestLayer):
def input_getter(self, shape, dtype):
key = jax.random.PRNGKey(seed=1234)
q_key, k_key, v_key = jax.random.split(key, 3)
return list(map(partial(jax.random.normal, shape=shape, dtype=dtype),
[q_key, k_key, v_key]))
def get_layer_name(self):
return 'dot_product_attn'
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
head_dim = 64
num_attention_heads = 16
num_gqa_groups = num_attention_heads
attn_mask_type = attrs[DotProductAttnAttr.ATTN_MASK_TYPE]
transpose_batch_sequence = attrs[DotProductAttnAttr.TRANSPOSE_BS]
praxis_p = pax_fiddle.Config(DotProductAttention,
name='mha',
dtype=dtype,
head_dim=head_dim,
num_attention_heads=num_attention_heads,
num_gqa_groups=num_gqa_groups,
attn_mask_type=attn_mask_type,
transpose_batch_sequence=transpose_batch_sequence)
flax_cls = partial(flax_DotProductAttention,
dtype=dtype,
head_dim=head_dim,
num_attention_heads=num_attention_heads,
num_gqa_groups=num_gqa_groups,
attn_mask_type=attn_mask_type,
transpose_batch_sequence=transpose_batch_sequence)
return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', [(32, 128, 16, 64)])
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', DotProductAttnAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class MultiHeadAttnAttr:
USE_BIAS = 'use_bias'
LN_TYPE = 'layernorm_type'
......@@ -730,36 +828,38 @@ class TestMultiHeadAttn(TestLayer):
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
head_dim = 64
num_heads = 16
num_attention_heads = 16
num_gqa_groups = attrs[MultiHeadAttnAttr.NUM_GQA_GROUPS] \
if MultiHeadAttnAttr.NUM_GQA_GROUPS in attrs else None
layernorm_type = attrs[MultiHeadAttnAttr.LN_TYPE]
zero_centered_gamma = attrs[MultiHeadAttnAttr.ZERO_CEN]
kernel_init = WeightInit.Gaussian(1.0)
use_bias = attrs[MultiHeadAttnAttr.USE_BIAS]
bias_init = WeightInit.Constant(0.0)
apply_residual_connection_post_layernorm = False
output_layernorm = False
input_layernorm = False
return_layernorm_output = False
attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE]
fuse_qkv: bool = True
fuse_qkv_params = True
transpose_batch_sequence = True
scale_attn_logits = False
scaled_query_init = True
float32_logits = False
praxis_p = pax_fiddle.Config(
MultiHeadAttention,
praxis_p = pax_fiddle.Config(MultiHeadAttention,
name='mha',
dtype=dtype,
head_dim=head_dim,
num_heads=num_heads,
num_attention_heads=num_attention_heads,
num_gqa_groups=num_gqa_groups,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
params_init=kernel_init,
use_bias=use_bias,
bias_init=bias_init,
apply_residual_connection_post_layernorm=apply_residual_connection_post_layernorm,
output_layernorm=output_layernorm,
return_layernorm_output=return_layernorm_output,
input_layernorm=input_layernorm,
attn_mask_type=attn_mask_type,
fuse_qkv=fuse_qkv,
fuse_qkv_params=fuse_qkv_params,
transpose_batch_sequence=transpose_batch_sequence,
scale_attn_logits=scale_attn_logits,
scaled_query_init=scaled_query_init,
......@@ -768,16 +868,17 @@ class TestMultiHeadAttn(TestLayer):
flax_MultiHeadAttention,
dtype=dtype,
head_dim=head_dim,
num_heads=num_heads,
num_attention_heads=num_attention_heads,
num_gqa_groups=num_gqa_groups,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init),
use_bias=use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
apply_residual_connection_post_layernorm=apply_residual_connection_post_layernorm,
output_layernorm=output_layernorm,
return_layernorm_output=return_layernorm_output,
input_layernorm=input_layernorm,
attn_mask_type=attn_mask_type,
fuse_qkv=fuse_qkv,
fuse_qkv_params=fuse_qkv_params,
transpose_batch_sequence=transpose_batch_sequence,
scale_attn_logits=scale_attn_logits,
scaled_query_init=scaled_query_init,
......@@ -1024,6 +1125,7 @@ class TestTransformer(TestLayer):
enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE]
enable_relative_embedding = True
relative_embedding = pax_fiddle.Config(RelativePositionBiases,
dtype=dtype,
num_attention_heads=num_attention_heads)
drop_path = 0.0
transpose_batch_sequence = attrs[TransformerLayerAttr.TRANSPOSE_BS]
......
......@@ -934,7 +934,7 @@ class EncoderLayer(nn.Module):
y = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="output_layer_norm")(y)
name="output_layernorm")(y)
return y
......@@ -1090,7 +1090,7 @@ class DecoderLayer(nn.Module):
z = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="output_layer_norm")(z)
name="output_layernorm")(z)
return z
......
......@@ -105,8 +105,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
bool flag_m512 = false;
bool flag_arb = false;
if ((sm_arch_ == 80 || sm_arch_ == 90)
&& (max_seqlen_q <= 512)
&& (max_seqlen_kv <= 512)
&& (max_seqlen_q <= 512 && max_seqlen_q % 64 == 0)
&& (max_seqlen_kv <= 512 && max_seqlen_kv % 64 == 0)
&& (head_dim == 64)
&& (num_attn_heads == num_gqa_groups)
&& ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
......
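The tightened check above only admits the max-512 fused-attention backend when both sequence lengths are multiples of 64 (in addition to being at most 512). Below is a small Python paraphrase of that condition, ignoring the bias- and mask-type clauses for brevity; the helper name is hypothetical and not part of the library:

    def uses_max512_backend(sm_arch, max_seqlen_q, max_seqlen_kv, head_dim,
                            num_attn_heads, num_gqa_groups):
        # Mirrors the updated eligibility test in nvte_get_fused_attn_backend
        return (sm_arch in (80, 90)
                and max_seqlen_q <= 512 and max_seqlen_q % 64 == 0
                and max_seqlen_kv <= 512 and max_seqlen_kv % 64 == 0
                and head_dim == 64
                and num_attn_heads == num_gqa_groups)

    assert uses_max512_backend(90, 512, 512, 64, 16, 16)
    assert not uses_max512_backend(90, 500, 512, 64, 16, 16)  # 500 is not a multiple of 64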
......@@ -53,6 +53,8 @@ pybind11::dict Registrations() {
dict["te_self_fused_attn_backward"] = EncapsulateFunction(SelfFusedAttnBackward);
dict["te_cross_fused_attn_forward"] = EncapsulateFunction(CrossFusedAttnForward);
dict["te_cross_fused_attn_backward"] = EncapsulateFunction(CrossFusedAttnBackward);
dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward);
dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward);
return dict;
}
......@@ -74,6 +76,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_self_fused_attn_bwd_workspace_sizes", &GetSelfFusedAttnBackwardWorkspaceSizes);
m.def("get_cross_fused_attn_fwd_workspace_sizes", &GetCrossFusedAttnForwardWorkspaceSizes);
m.def("get_cross_fused_attn_bwd_workspace_sizes", &GetCrossFusedAttnBackwardWorkspaceSizes);
m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes);
m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
pybind11::enum_<DType>(m, "DType", pybind11::module_local())
.value("kByte", DType::kByte)
......@@ -98,7 +102,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local())
.value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD)
.value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD);
.value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
.value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD);
pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local())
.value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend)
......
......@@ -1253,7 +1253,6 @@ pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
// TODO(rewang): add bias for cross attn?
auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
// FP16/BF16 doesn't use this tensor
......@@ -1488,5 +1487,265 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
nvte_tensor_pack_destroy(&aux_input_tensors);
}
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto k_shape = std::vector<size_t>{batch_size * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
// F16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
TensorWrapper query_workspace_tensor;
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input buffers from XLA
void *q = buffers[0];
void *k = buffers[1];
void *v = buffers[2];
void *bias = buffers[3];
void *q_cu_seqlens = buffers[4];
void *kv_cu_seqlens = buffers[5];
void *seed = buffers[6];
// output buffers from XLA
void *output = buffers[7];
void *softmax_aux = buffers[8];
void *rng_state = buffers[9];
void *workspace = buffers[10];
// tensor sizes
auto batch_size = descriptor.batch_size;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto num_heads = descriptor.num_heads;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto head_dim = descriptor.head_dim;
auto scaling_factor = descriptor.scaling_factor;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto k_shape = std::vector<size_t>{batch_size * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
// input tensors
auto dtype = descriptor.dtype;
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype);
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
// output tensors
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16
auto o_tensor = TensorWrapper(output, q_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
// prep RNG state
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, num_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
// auxiliary tensors (to be propagated to the backward pass later)
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend,
softmax_aux);
// cuDNN workspace
auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size},
descriptor.wkspace_dtype);
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
descriptor.is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors);
}
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto k_shape = std::vector<size_t>{batch_size * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto output_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
// F16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype);
auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
TensorWrapper query_workspace_tensor;
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input buffers from XLA
void *q = buffers[0];
void *k = buffers[1];
void *v = buffers[2];
void *bias = buffers[3];
void *softmax_aux = buffers[4];
void *rng_state = buffers[5];
void *output = buffers[6];
void *doutput = buffers[7];
void *q_cu_seqlens = buffers[8];
void *kv_cu_seqlens = buffers[9];
// output buffers from XLA
void *dq = buffers[10];
void *dk = buffers[11];
void *dv = buffers[12];
void *dbias = buffers[13];
void *workspace = buffers[14];
// tensor sizes
auto batch_size = descriptor.batch_size;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto num_heads = descriptor.num_heads;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto head_dim = descriptor.head_dim;
auto scaling_factor = descriptor.scaling_factor;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto k_shape = std::vector<size_t>{batch_size * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto output_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
// input tensors
auto dtype = descriptor.dtype;
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype);
auto output_tensor = TensorWrapper(output, output_shape, dtype);
auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
// output tensors
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dk_tensor = TensorWrapper(dk, k_shape, dtype);
auto dv_tensor = TensorWrapper(dv, v_shape, dtype);
auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
// auxiliary tensors (propagated from the forward pass)
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, num_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
rng_state, bias);
// cuDNN workspace
auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
auto wkspace_dtype = descriptor.wkspace_dtype;
auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_input_tensors);
}
} // namespace jax
} // namespace transformer_engine
......@@ -236,6 +236,20 @@ pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
} // namespace jax
} // namespace transformer_engine
......
......@@ -5,11 +5,19 @@
from .module import DenseGeneral, LayerNorm
from .module import LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase
from .transformer import extend_logical_axis_rules
from .transformer import MultiHeadAttention, RelativePositionBiases
from .transformer import DotProductAttention, MultiHeadAttention, RelativePositionBiases
from .transformer import TransformerLayer, TransformerLayerType
__all__ = [
'DenseGeneral', 'LayerNorm', 'LayerNormDenseGeneral', 'LayerNormMLP',
'TransformerEngineBase', 'extend_logical_axis_rules', 'MultiHeadAttention',
'RelativePositionBiases', 'TransformerLayer', 'TransformerLayerType',
'DenseGeneral',
'LayerNorm',
'LayerNormDenseGeneral',
'LayerNormMLP',
'TransformerEngineBase',
'extend_logical_axis_rules',
'DotProductAttention',
'MultiHeadAttention',
'RelativePositionBiases',
'TransformerLayer',
'TransformerLayerType',
]
......@@ -16,6 +16,7 @@ from transformer_engine_jax import NVTE_QKV_Layout
from .cpp_extensions import FusedAttnHelper
from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd
from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd
from .cpp_extensions import fused_attn_fwd, fused_attn_bwd
class AttnBiasType(Enum):
......@@ -37,6 +38,21 @@ class QKVLayout(Enum):
"""QKV layout"""
BS3HD = NVTE_QKV_Layout.NVTE_BS3HD
BSHD_BS2HD = NVTE_QKV_Layout.NVTE_BSHD_BS2HD
BSHD_BSHD_BSHD = NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD
def canonicalize_attn_mask_type(attn_mask_type: str):
"""Convert string attn_mask_type to AttnMaskType
TE-JAX currently fall back to the padding version kernels for the libraries integration.
The overhead between padding and non-padding version should be small.
However, we will lease this limitation in the near feature.
"""
if attn_mask_type in ['causal', 'padding_causal']:
return AttnMaskType.PADDING_CAUSAL_MASK
if attn_mask_type in ['no_mask', 'padding']:
return AttnMaskType.PADDING_MASK
raise ValueError(f"Unsupported {attn_mask_type=}, "
"supported attn_mask_type={'no_mask', 'padding', 'causal', 'padding_causal'}")
def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, attn_mask_type,
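A quick illustration of the mask-type canonicalization above, as a sketch: the assertions simply restate the mapping defined by the function in this hunk, and the import assumes the function is accessed directly from the module:

    from transformer_engine.jax.fused_attn import AttnMaskType, canonicalize_attn_mask_type

    assert canonicalize_attn_mask_type('causal') is AttnMaskType.PADDING_CAUSAL_MASK
    assert canonicalize_attn_mask_type('padding_causal') is AttnMaskType.PADDING_CAUSAL_MASK
    assert canonicalize_attn_mask_type('no_mask') is AttnMaskType.PADDING_MASK
    assert canonicalize_attn_mask_type('padding') is AttnMaskType.PADDING_MASK
    # Any other string raises ValueError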
......@@ -83,6 +99,10 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.nda
seed: jnp.ndarray, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
if mask is None:
batch, seqlen, *_ = qkv.shape
actual_seqlen = jnp.full((batch,), seqlen, dtype=jnp.int32)
else:
mask = jnp.logical_not(mask)
actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
output, softmax_aux, rng_state = self_fused_attn_fwd(qkv,
......@@ -159,13 +179,18 @@ def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask:
def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
if mask is None:
batch, s_q, *_ = q.shape
s_kv = kv.shape[1]
q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
else:
mask = jnp.logical_not(mask)
q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
else:
# When mask is padding + causal, the actual seqlen is not the last row, use max to find it
# When mask is causal, the actual seqlen is not the last row, use max to find it
kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))
output, softmax_aux, rng_state = cross_fused_attn_fwd(q,
......@@ -179,7 +204,9 @@ def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mas
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
output = checkpoint_name(output, 'context')
softmax_aux = checkpoint_name(softmax_aux, 'context')
rng_state = checkpoint_name(rng_state, 'context')
return output, (q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen)
......@@ -209,3 +236,100 @@ def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, d
_cross_fused_attn.defvjp(_cross_fused_attn_fwd_rule, _cross_fused_attn_bwd_rule)
def fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
scaling_factor: float, dropout_probability: float, is_training: bool):
"""
Dot product attention with separate query, key, and value tensors
"""
output = _fused_attn(q,
k,
v,
bias,
mask,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output
@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10))
def _fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray,
mask: jnp.ndarray, seed: jnp.ndarray, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float, dropout_probability: float,
is_training: bool):
output, _ = _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training)
return output
def _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
if mask is None:
batch, s_q, *_ = q.shape
s_kv = k.shape[1]
q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
else:
mask = jnp.logical_not(mask)
q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
else:
# When mask is causal, the actual seqlen is not the last row, use max to find it
kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))
output, softmax_aux, rng_state = fused_attn_fwd(q,
k,
v,
bias,
q_actual_seqlen,
kv_actual_seqlen,
seed,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
output = checkpoint_name(output, 'context')
softmax_aux = checkpoint_name(softmax_aux, 'context')
rng_state = checkpoint_name(rng_state, 'context')
return output, (q, k, v, bias, softmax_aux, rng_state, output, q_actual_seqlen,
kv_actual_seqlen)
def _fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, dz):
q, k, v, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx
grad_q, grad_k, grad_v, grad_bias = fused_attn_bwd(q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
dz,
q_actual_seqlen,
kv_actual_seqlen,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
if attn_bias_type == AttnBiasType.NO_BIAS:
grad_bias = None
return grad_q, grad_k, grad_v, grad_bias, None, None
_fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule)
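A minimal sketch of calling the new functional fused_attn API with separate query/key/value tensors. The argument names follow the signature above; the shapes, dtype, and the choice of a PRNG key as the seed are illustrative assumptions:

    import jax
    import jax.numpy as jnp
    from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, fused_attn

    b, s_q, s_kv, h, d = 2, 512, 512, 16, 64    # BSHD_BSHD_BSHD layout
    q = jnp.ones((b, s_q, h, d), dtype=jnp.bfloat16)
    k = jnp.ones((b, s_kv, h, d), dtype=jnp.bfloat16)
    v = jnp.ones((b, s_kv, h, d), dtype=jnp.bfloat16)

    out = fused_attn(q, k, v,
                     bias=None, mask=None, seed=jax.random.PRNGKey(0),
                     attn_bias_type=AttnBiasType.NO_BIAS,
                     attn_mask_type=AttnMaskType.PADDING_MASK,
                     scaling_factor=1.0 / d**0.5,
                     dropout_probability=0.0,
                     is_training=True)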
......@@ -4,5 +4,6 @@
"""Praxis related Modules"""
from .module import FusedSoftmax, LayerNorm
from .module import LayerNormLinear, LayerNormMLP, Linear, TransformerEngineBaseLayer
from .transformer import MultiHeadAttention, RelativePositionBiases, TransformerLayer
from .transformer import DotProductAttention, MultiHeadAttention
from .transformer import RelativePositionBiases, TransformerLayer
from ..flax.transformer import TransformerLayerType
......@@ -6,6 +6,7 @@ Praxis Modules related Transformer
"""
from functools import partial
from typing import Optional, Sequence, Tuple
import warnings
from praxis import pax_fiddle
from praxis.base_layer import WeightInit
......@@ -13,9 +14,11 @@ from praxis.pytypes import JTensor
from .module import TransformerEngineBaseLayer
from ..flax.transformer import TransformerLayerType
from ..flax.transformer import DotProductAttention as flax_DotProductAttention
from ..flax.transformer import MultiHeadAttention as flax_MultiHeadAttention
from ..flax.transformer import RelativePositionBiases as flax_RelativePositionBiases
from ..flax.transformer import TransformerLayer as flax_TransformerLayer
from ..fused_attn import AttnBiasType, AttnMaskType
class RelativePositionBiases(TransformerEngineBaseLayer):
......@@ -59,30 +62,117 @@ class RelativePositionBiases(TransformerEngineBaseLayer):
return self.relative_position_bias(q_seqlen, k_seqlen, bidirectional)
class DotProductAttention(TransformerEngineBaseLayer):
"""DotProductAttention"""
head_dim: int = 0
num_attention_heads: int = 0
num_gqa_groups: Optional[int] = None
attention_dropout: float = 0.
attn_mask_type: str = 'causal'
attn_bias_type: Optional[str] = None
dropout_rng_name: str = 'dropout'
float32_logits: bool = False
qkv_layout: str = 'bshd_bshd_bshd'
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = True
def setup(self) -> None:
"""setup"""
super().setup()
assert self.head_dim > 0, f'{self.head_dim=}'
assert self.num_attention_heads > 0, f'{self.num_attention_heads=}'
dpa_cls = partial(flax_DotProductAttention,
head_dim=self.head_dim,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
attn_mask_type=self.attn_mask_type,
attn_bias_type=self.attn_bias_type,
attention_dropout=self.attention_dropout,
dtype=self.dtype,
dropout_rng_name=self.dropout_rng_name,
float32_logits=self.float32_logits,
qkv_layout=self.qkv_layout,
scale_factor=self.scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence)
self.create_layer("dot_product_attention", dpa_cls)
def __call__(self,
query: JTensor,
key: JTensor,
value: JTensor,
mask: Optional[JTensor] = None,
bias: Optional[JTensor] = None,
*,
deterministic: bool = False) -> JTensor:
"""__call__"""
return self.dot_product_attention(query,
key,
value,
mask,
bias,
deterministic=deterministic)
class MultiHeadAttention(TransformerEngineBaseLayer):
"""MultiHeadAttention"""
head_dim: int = 64
num_heads: int = 16
num_gqa_groups: int | None = None
dropout_rate: float = 0.
head_dim: int = 0
num_attention_heads: int = 0
num_gqa_groups: Optional[int] = None
attention_dropout: float = 0.
dropout_rng_name: str = 'dropout'
input_layernorm: bool = True
layernorm_type: str = "layernorm"
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
return_layernorm_output: bool = False
use_bias: bool = False
bias_init: WeightInit = WeightInit.Constant(0.0)
apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False
attn_mask_type: str = 'causal'
fuse_qkv: bool = True
attn_bias_type: Optional[str] = None
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = True
enable_sequence_parallel: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
float32_logits: bool = False
# Deprecated parameters
num_heads: Optional[int] = None
dropout_rate: Optional[float] = None
output_layernorm: Optional[bool] = None
apply_residual_connection_post_layernorm: Optional[bool] = None
fuse_qkv: Optional[bool] = None
def __post_init__(self):
# Deal with the deprecated parameters
if self.num_heads is not None:
self.num_attention_heads = self.num_heads
warnings.warn(
f"{__class__}.num_heads is deprecated. It will be removed recently. "
f"Please uses {__class__}.num_attention_heads as the new API.", DeprecationWarning)
if self.dropout_rate is not None:
self.attention_dropout = self.dropout_rate
warnings.warn(
f"{__class__}.dropout_rate is deprecated. It will be removed recently. "
f"Please use {__class__}.attention_dropout as the new API.", DeprecationWarning)
if self.apply_residual_connection_post_layernorm is not None:
warnings.warn(
f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
f"It will be removed recently, please use {__class__}.return_layernorm_output.",
DeprecationWarning)
if self.fuse_qkv is not None:
warnings.warn(
f"{__class__}.fuse_qkv is deprecated. It will be removed recently. "
f"Please use {__class__}.fuse_qkv_params as the new API.", DeprecationWarning)
assert self.output_layernorm is None, (
f"{__class__}.output_layernorm is deprecated. It will be removed recently. "
f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm.")
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
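Given the deprecation handling above, existing Praxis configs migrate by renaming a few fields. A hedged before/after sketch (only the renamed parameters are shown; the old names still work but emit a DeprecationWarning):

    from praxis import pax_fiddle
    from transformer_engine.jax.praxis import MultiHeadAttention

    # Old (deprecated) parameter names
    old_p = pax_fiddle.Config(MultiHeadAttention, name='mha',
                              head_dim=64, num_heads=16,
                              dropout_rate=0.1, fuse_qkv=True)

    # New parameter names
    new_p = pax_fiddle.Config(MultiHeadAttention, name='mha',
                              head_dim=64, num_attention_heads=16,
                              attention_dropout=0.1, fuse_qkv_params=True)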
......@@ -91,24 +181,28 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
"""setup"""
super().setup()
assert self.head_dim > 0, f'{self.head_dim=}'
assert self.num_attention_heads > 0, f'{self.num_attention_heads=}'
mha_cls = partial(
flax_MultiHeadAttention,
dtype=self.dtype,
head_dim=self.head_dim,
num_heads=self.num_heads,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
dropout_rate=self.dropout_rate,
attention_dropout=self.attention_dropout,
dropout_rng_name=self.dropout_rng_name,
input_layernorm=self.input_layernorm,
layernorm_type=self.layernorm_type,
layernorm_epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
return_layernorm_output=self.return_layernorm_output,
kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
use_bias=self.use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
output_layernorm=self.output_layernorm,
attn_mask_type=self.attn_mask_type,
fuse_qkv=self.fuse_qkv,
attn_bias_type=self.attn_bias_type,
fuse_qkv_params=self.fuse_qkv_params,
transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel,
scale_attn_logits=self.scale_attn_logits,
......@@ -140,7 +234,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
hidden_size: int = 512
mlp_hidden_size: int = 2048
num_attention_heads: int = 8
num_gqa_groups: int | None = None
num_gqa_groups: Optional[int] = None
layernorm_type: str = 'layernorm'
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
......@@ -158,6 +252,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
float32_attention_logits: bool = False
layer_type: TransformerLayerType = TransformerLayerType.ENCODER
self_attn_mask_type: str = 'causal'
self_attn_bias_type: Optional[str] = None
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
enable_relative_embedding: bool = True
......@@ -226,6 +321,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
float32_attention_logits=self.float32_attention_logits,
layer_type=self.layer_type,
self_attn_mask_type=self.self_attn_mask_type,
self_attn_bias_type=self.self_attn_bias_type,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
enable_relative_embedding=self.enable_relative_embedding,
......