Unverified Commit 8f6c5248 authored by zlsh80826, committed by GitHub

[JAX][Common] Support GQA (#578)



* Support num_gqa_groups argument
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add GQA support to the JAX bridge code
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix the KV stride of the arbitrary-seqlen backend
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Completely rewrite the fused attention tests and add GQA coverage
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Support unfused GQA
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Calculate seqlen before the primitive for better performance
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add GQA layer tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Apply code style checks to te_jax
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Apply code style checks to tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add num_gqa_groups documentation
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refine the qkv_type
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Correct the variable naming
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Handle Max512 CAUSAL
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add a WAR for the latest JAX image
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent daad219f
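For context, grouped-query attention (GQA) keeps the usual number of query heads but shares each key/value head across a group of query heads, so num_gqa_groups (the number of KV heads) can be smaller than num_attention_heads. A minimal shape sketch of the convention the changes below assume (illustrative only, not TE's API):

    import jax.numpy as jnp

    batch, seqlen, head_dim = 2, 16, 64
    num_attention_heads, num_gqa_groups = 8, 4    # every 2 query heads share one KV head
    assert num_attention_heads % num_gqa_groups == 0

    query = jnp.zeros((batch, seqlen, num_attention_heads, head_dim))
    key = jnp.zeros((batch, seqlen, num_gqa_groups, head_dim))
    value = jnp.zeros((batch, seqlen, num_gqa_groups, head_dim))
    # num_gqa_groups == num_attention_heads recovers standard MHA;
    # num_gqa_groups == 1 is multi-query attention (MQA).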
@@ -4,6 +4,9 @@
 set -xe
+
+# WAR(rewang) for the "Check failed: reduction_kind.has_value()"
+export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_xla_runtime_executable=true"
 : ${TE_PATH:=/opt/transformerengine}
 pytest -Wignore -v $TE_PATH/tests/jax/test_distributed_*
@@ -14,5 +14,7 @@ pytest -Wignore -v $TE_PATH/examples/jax/mnist
 # Make encoder tests to have run-to-run deterministic to have the stable CI results
 export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
+# WAR(rewang) for the "Check failed: reduction_kind.has_value()"
+export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_xla_runtime_executable=true"
 pytest -Wignore -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
 pytest -Wignore -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
@@ -9,13 +9,14 @@ import jax
 import jax.numpy as jnp
 import pytest

-from transformer_engine.common.recipe import Format
-from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
-from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
 from utils import assert_allclose
 from utils import DecoderLayer as RefDecoderLayer
 from utils import EncoderLayer as RefEncoderLayer

+from transformer_engine.common.recipe import Format
+from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
+from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
+
 is_fp8_supported, reason = is_fp8_available()
@@ -85,8 +86,13 @@ _KEY_OF_LAYERNORM_TYPE = 'layernorm_type'
 _KEY_OF_ZERO_CENTERED_GAMMA = 'zero_centered_gamma'
 _KEY_OF_TRANSPOSE_BS = 'transpose_batch_sequence'
 _KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits"
+_KEY_OF_NUM_HEADS = 'num_attention_heads'
+_KEY_OF_NUM_GQA_GROUPS = 'num_gqa_groups'

-BASE_ATTRS = {_KEY_OF_TRANSPOSE_BS: True}
+BASE_ATTRS = {
+    _KEY_OF_TRANSPOSE_BS: True,
+    _KEY_OF_NUM_HEADS: 8,
+}

 ATTRS = [{
     _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
@@ -129,6 +135,9 @@ ATTRS = [{
     _KEY_OF_DROPOUT_RATE: 0.0,
     _KEY_OF_MLP_ACTIVATIONS: (('gelu', 'linear')),
     _KEY_OF_FUSE_MLP_WI: True
+}, {
+    _KEY_OF_NUM_HEADS: 8,
+    _KEY_OF_NUM_GQA_GROUPS: 4
 }]

 ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
@@ -137,21 +146,13 @@ ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
 class TestEncoderLayer:

     @staticmethod
-    def sync_params(ref, target, attrs):
-        fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)
+    def sync_params(ref, target):
         unfreeze_target = flax.core.unfreeze(target)
-        if fuse_qkv:
-            unfreeze_target['attention']['qkv']['kernel'] = \
-                jnp.reshape(ref['attention']['qkv']['kernel'],
-                            unfreeze_target['attention']['qkv']['kernel'].shape)
-        else:
-            unfreeze_target['attention']['query']['kernel'] = \
-                ref['attention']['query']['kernel']
-            unfreeze_target['attention']['key']['kernel'] = \
-                ref['attention']['key']['kernel']
-            unfreeze_target['attention']['value']['kernel'] = \
-                ref['attention']['value']['kernel']
+        unfreeze_attn_scope = unfreeze_target['attention']
+        ref_attn_scope = ref['attention']
+        for key in ref_attn_scope.keys():
+            unfreeze_attn_scope[key]['kernel'] = \
+                ref_attn_scope[key]['kernel'].reshape(unfreeze_attn_scope[key]['kernel'].shape)
         unfreeze_target['mlp']['wi_kernel'] = \
             jnp.reshape(ref['mlp']['wi']['kernel'], unfreeze_target['mlp']['wi_kernel'].shape)
         unfreeze_target['mlp']['wo_kernel'] = \
@@ -196,7 +197,7 @@
         test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
                                                               test_masks)

-        ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params, attrs)
+        ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params)

         ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng)
         test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng)
@@ -242,7 +243,7 @@
         test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
                                                               test_masks)

-        ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params, attrs)
+        ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params)

         if FP8Helper.is_fp8_enabled():
             for _ in range(4):
@@ -266,7 +267,10 @@
         assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol)    # dgrad

         def reorganize_test_wgrad(test_wgrad, attrs):
-            fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)
+            num_heads = attrs.get(_KEY_OF_NUM_HEADS)
+            num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads)
+            fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True) and \
+                num_heads == num_gqa_groups

             attn_name = 'attention'
             unfreeze_test_wgrad = flax.core.unfreeze(test_wgrad)
@@ -280,10 +284,12 @@
                 unfreeze_test_wgrad['pre_attention_layer_norm']['ln_bias'] = \
                     unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
                 del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']

-            if fuse_qkv:
-                unfreeze_test_wgrad[attn_name]['qkv']['kernel'] = \
-                    jnp.reshape(unfreeze_test_wgrad[attn_name]['qkv']['kernel'],
-                                (unfreeze_test_wgrad[attn_name]['qkv']['kernel'].shape[0], -1))
+            for key in unfreeze_test_wgrad[attn_name].keys():
+                unfreeze_test_wgrad[attn_name][key]['kernel'] = \
+                    jnp.reshape(unfreeze_test_wgrad[attn_name][key]['kernel'],
+                                (unfreeze_test_wgrad[attn_name][key]['kernel'].shape[0], -1))

             unfreeze_test_wgrad['pre_mlp_layer_norm'] = {}
             unfreeze_test_wgrad['pre_mlp_layer_norm']['scale'] = \
                 unfreeze_test_wgrad['mlp']['scale']
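With GQA in play the tests can no longer assume a single packed QKV kernel, because the query and key/value projections have different widths. Roughly, the decision above reduces to the sketch below (the dict keys mirror the _KEY_OF_* constants; 'fuse_qkv_params' is an assumed stand-in for whatever string _KEY_OF_FUSE_QKV_PARAMS actually holds):

    def is_fused_qkv(attrs):
        num_heads = attrs.get('num_attention_heads')
        num_gqa_groups = attrs.get('num_gqa_groups', num_heads)
        # A packed QKV kernel only exists when q, k and v all have the same width.
        return attrs.get('fuse_qkv_params', True) and num_heads == num_gqa_groups

    print(is_fused_qkv({'num_attention_heads': 8}))                       # True
    print(is_fused_qkv({'num_attention_heads': 8, 'num_gqa_groups': 4}))  # False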
@@ -348,26 +354,14 @@ class TestEncoderLayer:
 class TestDecoderLayer:

     @staticmethod
-    def sync_params(ref, target, attrs):
-        fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)
+    def sync_params(ref, target):
         unfreeze_target = flax.core.unfreeze(target)
-        if fuse_qkv:
-            unfreeze_target['self_attention']['qkv']['kernel'] = \
-                jnp.reshape(ref['self_attention']['qkv']['kernel'],
-                            unfreeze_target['self_attention']['qkv']['kernel'].shape)
-            unfreeze_target['encoder_decoder_attention']['kv']['kernel'] = \
-                jnp.reshape(ref['encoder_decoder_attention']['kv']['kernel'],
-                            unfreeze_target['encoder_decoder_attention']['kv']['kernel'].shape)
-        else:
-            unfreeze_target['self_attention']['query']['kernel'] = \
-                ref['self_attention']['query']['kernel']
-            unfreeze_target['self_attention']['key']['kernel'] = \
-                ref['self_attention']['key']['kernel']
-            unfreeze_target['self_attention']['value']['kernel'] = \
-                ref['self_attention']['value']['kernel']
-            unfreeze_target['encoder_decoder_attention']['query']['kernel'] = \
-                ref['encoder_decoder_attention']['query']['kernel']
+        for scope in ['self_attention', 'encoder_decoder_attention']:
+            unfreeze_scope = unfreeze_target[scope]
+            ref_scope = ref[scope]
+            for key in unfreeze_scope.keys():
+                unfreeze_scope[key]['kernel'] = \
+                    ref_scope[key]['kernel'].reshape(unfreeze_scope[key]['kernel'].shape)
         unfreeze_target['mlp']['wi_kernel'] = \
             jnp.reshape(ref['mlp']['wi']['kernel'], unfreeze_target['mlp']['wi_kernel'].shape)
         unfreeze_target['mlp']['wo_kernel'] = \
@@ -412,7 +406,7 @@
         test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
                                                               test_masks)

-        ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params, attrs)
+        ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params)

         ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng)
         test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng)
@@ -459,7 +453,7 @@
         test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
                                                               test_masks)

-        ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params, attrs)
+        ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params)

         if FP8Helper.is_fp8_enabled():
             for _ in range(4):
@@ -483,11 +477,14 @@
         assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol)    # dgrad

         def reorganize_test_wgrad(test_wgrad, attrs):
-            fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)
-            attn_name = 'self_attention'
+            num_heads = attrs.get(_KEY_OF_NUM_HEADS)
+            num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads)
+            fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True) and \
+                num_heads == num_gqa_groups

             unfreeze_test_wgrad = flax.core.unfreeze(test_wgrad)
             if "output_layernorm" not in attrs:
+                attn_name = 'self_attention'
                 unfreeze_test_wgrad['pre_self_attention_layer_norm'] = {}
                 pre_attn_layer_key = 'qkv' if fuse_qkv else 'query'
                 unfreeze_test_wgrad['pre_self_attention_layer_norm']['scale'] = \
@@ -498,14 +495,11 @@
                     unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
                 del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']

-            if fuse_qkv:
-                unfreeze_test_wgrad[attn_name]['qkv']['kernel'] = \
-                    jnp.reshape(unfreeze_test_wgrad[attn_name]['qkv']['kernel'],
-                                (unfreeze_test_wgrad[attn_name]['qkv']['kernel'].shape[0], -1))
-                attn_name = 'encoder_decoder_attention'
-                unfreeze_test_wgrad[attn_name]['kv']['kernel'] = \
-                    jnp.reshape(unfreeze_test_wgrad[attn_name]['kv']['kernel'],
-                                (unfreeze_test_wgrad[attn_name]['kv']['kernel'].shape[0], -1))
+            for scope in ['self_attention', 'encoder_decoder_attention']:
+                for key in unfreeze_test_wgrad[scope].keys():
+                    unfreeze_test_wgrad[scope][key]['kernel'] = \
+                        jnp.reshape(unfreeze_test_wgrad[scope][key]['kernel'],
+                                    (unfreeze_test_wgrad[scope][key]['kernel'].shape[0], -1))

             unfreeze_test_wgrad['pre_cross_attention_layer_norm'] = {}
             unfreeze_test_wgrad['pre_cross_attention_layer_norm']['scale'] = \
...
@@ -12,6 +12,8 @@ from praxis import pax_fiddle
 from praxis.base_layer import WeightInit, DEFAULT_INIT_MUTABLE_LIST
 import pytest

+from utils import assert_allclose
+
 from transformer_engine.common.recipe import DelayedScaling, Format
 from transformer_engine.jax import fp8_autocast, update_fp8_metas, update_collections
 from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral
@@ -23,12 +25,12 @@ from transformer_engine.jax.flax import TransformerLayer as flax_TransformerLayer
 from transformer_engine.jax.flax.module import Softmax
 from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
 from transformer_engine.jax.praxis import LayerNorm
-from transformer_engine.jax.praxis import FusedSoftmax, LayerNorm
+from transformer_engine.jax.praxis import FusedSoftmax
 from transformer_engine.jax.praxis import LayerNormLinear, LayerNormMLP, Linear
 from transformer_engine.jax.praxis import MultiHeadAttention, RelativePositionBiases
-from transformer_engine.jax.praxis import TransformerEngineBaseLayer, TransformerLayer, TransformerLayerType
+from transformer_engine.jax.praxis import TransformerEngineBaseLayer
+from transformer_engine.jax.praxis import TransformerLayer, TransformerLayerType
 from transformer_engine.jax.softmax import SoftmaxType
-from utils import assert_allclose

 is_fp8_supported, reason = is_fp8_available()
@@ -674,6 +676,8 @@ class MultiHeadAttnAttr:
     LN_TYPE = 'layernorm_type'
     ATTN_MASK_TYPE = 'attn_mask_type'
     ZERO_CEN = 'zero_centered_gamma'
+    NUM_ATTN_HEADS = 'num_attention_heads'
+    NUM_GQA_GROUPS = 'num_gqa_groups'
     ATTRS = [{
         USE_BIAS: True,
         LN_TYPE: 'layernorm',
@@ -704,6 +708,13 @@
         LN_TYPE: 'rmsnorm',
         ZERO_CEN: False,
         ATTN_MASK_TYPE: 'causal'
+    }, {
+        USE_BIAS: True,
+        LN_TYPE: 'rmsnorm',
+        ZERO_CEN: False,
+        NUM_ATTN_HEADS: 8,
+        NUM_GQA_GROUPS: 4,
+        ATTN_MASK_TYPE: 'causal'
     }]
...
 # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
+"""Utility for the TE layer tests"""
 import functools
 import math
@@ -28,6 +29,9 @@ Initializer = Callable[[PRNGKey, Shape, DType], Array]

 def is_devices_enough(required):
+    """
+    Check if the available GPUs is enough
+    """
     return len(jax.devices()) >= required
@@ -121,9 +125,9 @@ def dot_product_attention(query: Array,
     query: queries for calculating attention with shape of `[batch, q_length,
       num_heads, qk_depth_per_head]`.
     key: keys for calculating attention with shape of `[batch, kv_length,
-      num_heads, qk_depth_per_head]`.
+      num_gqa_groups, qk_depth_per_head]`.
     value: values to be used in attention with shape of `[batch, kv_length,
-      num_heads, v_depth_per_head]`.
+      num_gqa_groups, v_depth_per_head]`.
     bias: bias for the attention weights. This should be broadcastable to the
       shape `[batch, num_heads, q_length, kv_length]` This can be used for
       incorporating causal masks, padding masks, proximity bias, etc.
@@ -141,21 +145,31 @@ def dot_product_attention(query: Array,
     batch_dim = 1 if transpose_batch_sequence else 0
     assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], (
         'q, k, v batch dims must match.')
-    assert query.shape[-2] == key.shape[-2] == value.shape[-2], ('q, k, v num_heads must match.')
     sequence_dim = 0 if transpose_batch_sequence else 1
     assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.'
-    assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'
+    assert key.shape[-2] == value.shape[-2], 'k, v num_heads must match.'
+    assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.'

     # Casting logits and softmax computation for float32 for model stability.
     if float32_logits:
         query = query.astype(jnp.float32)
         key = key.astype(jnp.float32)

-    # `attn_weights`: [batch, num_heads, q_length, kv_length]
+    # `attn_weights`: [batch, num_heads, groups, q_length, kv_length]
+    h_q, h_kv = query.shape[-2], key.shape[-2]
+    assert (h_q % h_kv == 0) and (h_q >= h_kv)
+    group_size = h_q // h_kv
+    grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))
     if transpose_batch_sequence:
-        attn_weights = jnp.einsum('qbhd,kbhd->bhqk', query, key)
+        attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key)
     else:
-        attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)
+        attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)
+
+    # reshape back to normal DPA shape for bias/softmax/dropout
+    b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
+    attn_weights_without_groups_shape = (b, h * g, q, k)
+    attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)

     # Apply attention bias: masking, dropout, proximity bias, etc.
     if bias is not None:
@@ -174,11 +188,13 @@ def dot_product_attention(query: Array,
         multiplier = (keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype))
         attn_weights = attn_weights * multiplier

+    attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
+
     # Take the linear combination of `value`.
     if transpose_batch_sequence:
-        return jnp.einsum('bhqk,kbhd->qbhd', attn_weights, value)
+        return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape)

-    return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)
+    return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape)


 class DenseGeneral(nn.Module):
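The reference implementation above folds the query heads into (kv_heads, group_size) so a single einsum can attend each query group against its shared KV head. A self-contained sketch of that grouping for the batch-first layout (no bias, mask, or dropout; the 1/sqrt(head_dim) scaling is added here for completeness, while the reference handles scaling separately):

    import jax
    import jax.numpy as jnp

    def grouped_dot_product_attention(query, key, value):
        # query: [b, q_len, h_q, d]; key/value: [b, kv_len, h_kv, d] with h_q % h_kv == 0
        h_q, h_kv = query.shape[-2], key.shape[-2]
        assert h_q % h_kv == 0
        group_size = h_q // h_kv
        grouped_q = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))
        # logits: [b, h_kv, group, q_len, kv_len]
        logits = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_q, key)
        weights = jax.nn.softmax(logits / jnp.sqrt(query.shape[-1]), axis=-1)
        # contract against the shared KV heads and restore [b, q_len, h_q, d]
        out = jnp.einsum('bhgqk,bkhd->bqhgd', weights, value)
        return out.reshape(query.shape)

    q = jnp.ones((2, 8, 8, 16))
    k = jnp.ones((2, 8, 4, 16))
    v = jnp.ones((2, 8, 4, 16))
    print(grouped_dot_product_attention(q, k, v).shape)    # (2, 8, 8, 16)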
@@ -235,7 +251,8 @@
         if self.use_bias:
             bias = nn_partitioning.param_with_axes('bias',
-                                                   self.bias_init, (self.features,),
+                                                   self.bias_init,
+                                                   self.features,
                                                    self.dtype,
                                                    axes=self.bias_axes)
         else:
@@ -332,6 +349,7 @@ class MultiHeadAttention(nn.Module):
     Attributes:
       num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
         should be divisible by the number of heads.
+      num_gqa_groups: number of kv attention heads
      head_dim: dimension of each head.
       dtype: the dtype of the computation.
       dropout_rate: dropout rate
@@ -340,9 +358,10 @@
         numerical issues with bfloat16.
     """

-    num_heads: int
-    head_dim: int
-    transpose_batch_sequence: bool
+    num_heads: int = 8
+    num_gqa_groups: int | None = None
+    head_dim: int = 64
+    transpose_batch_sequence: bool = True
     dtype: DType = jnp.float32
     dropout_rate: float = 0.
     kernel_init: Initializer = None
@@ -354,6 +373,8 @@
     def __post_init__(self):
         if self.kernel_init is None:
             self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
+        if self.num_gqa_groups is None:
+            self.num_gqa_groups = self.num_attention_heads
         super().__post_init__()

     @nn.compact
@@ -393,18 +414,24 @@
         Returns:
           output of shape `[batch, length, q_features]`.
         """
-        projection = functools.partial(DenseGeneral,
+        q_projection = functools.partial(DenseGeneral,
                                        axis=-1,
                                        features=self.num_heads * self.head_dim,
                                        kernel_axes=('embed', 'joined_kv'),
                                        dtype=self.dtype)
+        kv_projection = functools.partial(DenseGeneral,
+                                          axis=-1,
+                                          features=self.num_gqa_groups * self.head_dim,
+                                          kernel_axes=('embed', 'joined_kv'),
+                                          dtype=self.dtype)

         # NOTE: T5 does not explicitly rescale the attention logits by
         # 1/sqrt(depth_kq)! This is folded into the initializers of the
         # linear transformations, which is equivalent under Adafactor
         depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
-        query_init = lambda *args: self.kernel_init(*args) / (    # pylint: disable=unnecessary-lambda-assignment
-            depth_scaling if self.scaled_query_init else 1.0)
+        query_init = lambda *args: self.kernel_init(*args) / (depth_scaling
+                                                              if self.scaled_query_init else 1.0)

         # Project inputs_q to multi-headed q/k/v
         # dimensions are then [batch, length, num_heads, head_dim]
@@ -417,13 +444,17 @@
             v_shape = (shape[0], shape[1] // 3)
             q_kernel = query_init(key, q_shape, dtype)
-            k_kernel = self.kernel_init(key, k_shape, dtype)    # pylint: disable=too-many-function-args
-            v_kernel = self.kernel_init(key, v_shape, dtype)    # pylint: disable=too-many-function-args
+            k_kernel = self.kernel_init(key, k_shape, dtype)
+            v_kernel = self.kernel_init(key, v_shape, dtype)
             return jnp.concatenate([q_kernel, k_kernel, v_kernel], axis=-1, dtype=dtype)

+        is_self_attn = (inputs_q is inputs_kv)
+        is_gqa = (self.num_heads != self.num_gqa_groups)
+        is_qkvpack = (is_self_attn and not is_gqa)
+
         if self.fuse_qkv:
-            if inputs_q is inputs_kv:
+            if is_qkvpack:
                 qkv_proj = DenseGeneral(axis=-1,
                                         features=self.num_heads * self.head_dim * 3,
                                         kernel_axes=('embed', 'joined_kv'),
@@ -436,24 +467,24 @@
                 if self.scale_attn_logits:
                     query = query / depth_scaling
             else:
-                query = projection(kernel_init=query_init, name='query')( \
+                query = q_projection(kernel_init=query_init, name='query')( \
                     (inputs_q / depth_scaling) if self.scale_attn_logits else inputs_q)
                 kv_proj = DenseGeneral(axis=-1,
-                                       features=self.num_heads * self.head_dim * 2,
+                                       features=self.num_gqa_groups * self.head_dim * 2,
                                        kernel_axes=('embed', 'joined_kv'),
                                        kernel_init=self.kernel_init,
                                        name='kv',
                                        dtype=self.dtype)(inputs_kv)
-                key, value = jnp.split(kv_proj, [self.num_heads * self.head_dim], axis=-1)
+                key, value = jnp.split(kv_proj, [self.num_gqa_groups * self.head_dim], axis=-1)
         else:
-            query = projection(kernel_init=query_init, name='query')( \
+            query = q_projection(kernel_init=query_init, name='query')( \
                 (inputs_q / depth_scaling) if self.scale_attn_logits else inputs_q)
-            key = projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
-            value = projection(kernel_init=self.kernel_init, name='value')(inputs_kv)
+            key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
+            value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv)

-        query = query.reshape((query.shape[0], query.shape[1], self.num_heads, self.head_dim))
-        key = key.reshape((key.shape[0], key.shape[1], self.num_heads, self.head_dim))
-        value = value.reshape((value.shape[0], value.shape[1], self.num_heads, self.head_dim))
+        query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
+        key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
+        value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))

         if self.transpose_batch_sequence:
             query = nn_partitioning.with_sharding_constraint(query,
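Under GQA the query and key/value projections no longer share an output width, which is why the code above splits the single `projection` partial into `q_projection` and `kv_projection` and only packs QKV when self-attention is used without grouping. A rough standalone sketch of that sizing logic using plain Flax Dense layers (illustrative; TE's DenseGeneral, sharding axes and initializers are omitted):

    import flax.linen as nn
    import jax.numpy as jnp

    class ToyGQAProjections(nn.Module):
        num_heads: int = 8
        num_gqa_groups: int = 4
        head_dim: int = 64

        @nn.compact
        def __call__(self, inputs_q, inputs_kv):
            is_self_attn = inputs_q is inputs_kv
            is_gqa = self.num_heads != self.num_gqa_groups
            if is_self_attn and not is_gqa:
                # q, k and v all have the same width, so one packed projection works
                qkv = nn.Dense(3 * self.num_heads * self.head_dim, name='qkv')(inputs_q)
                query, key, value = jnp.split(qkv, 3, axis=-1)
            else:
                query = nn.Dense(self.num_heads * self.head_dim, name='query')(inputs_q)
                key = nn.Dense(self.num_gqa_groups * self.head_dim, name='key')(inputs_kv)
                value = nn.Dense(self.num_gqa_groups * self.head_dim, name='value')(inputs_kv)
            query = query.reshape((*query.shape[:-1], self.num_heads, self.head_dim))
            key = key.reshape((*key.shape[:-1], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:-1], self.num_gqa_groups, self.head_dim))
            return query, key, value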
@@ -476,7 +507,7 @@
             # fusion optimization. This also enables the "scatter via one-hot
             # broadcast" trick, which means we do a one-hot broadcast instead of a
             # scatter/gather operations, resulting in a 3-4x speedup in practice.
-            swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3])    # pylint: disable=unnecessary-lambda-assignment
+            swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3])
             cached_key = self.variable('cache', 'cached_key', jnp.zeros, swap_dims(key.shape),
                                        key.dtype)
             cached_value = self.variable('cache', 'cached_value', jnp.zeros, swap_dims(value.shape),
@@ -755,7 +786,8 @@ class RelativePositionBiases(nn.Module):
 class EncoderLayer(nn.Module):
     """Transformer encoder layer."""
     relative_embedding: nn.Module = None
-    num_heads: int = 8
+    num_attention_heads: int = 8
+    num_gqa_groups: int | None = None
     head_dim: int = 64
     dropout_rate: float = 0.1
     transpose_batch_sequence: bool = True
@@ -773,6 +805,11 @@ class EncoderLayer(nn.Module):
     fuse_qkv_params: bool = True
     fuse_mlp_wi: bool = False

+    def __post_init__(self):
+        if self.num_gqa_groups is None:
+            self.num_gqa_groups = self.num_attention_heads
+        super().__post_init__()
+
     @nn.compact
     def __call__(self, inputs, encoder_mask=None, deterministic=False):
         # Relative position embedding as attention biases.
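Both reference layers default the new field in __post_init__, so existing configurations that never mention num_gqa_groups keep plain multi-head behavior. The pattern in isolation (a sketch; the real modules carry many more fields):

    import flax.linen as nn

    class ToyLayer(nn.Module):
        num_attention_heads: int = 8
        num_gqa_groups: int | None = None

        def __post_init__(self):
            # No grouping requested: fall back to one KV head per query head (MHA).
            if self.num_gqa_groups is None:
                self.num_gqa_groups = self.num_attention_heads
            # Assignment must happen before Flax finalizes the dataclass fields.
            super().__post_init__()

    print(ToyLayer().num_gqa_groups)                    # 8
    print(ToyLayer(num_gqa_groups=2).num_gqa_groups)    # 2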
@@ -782,7 +819,7 @@
         if self.relative_embedding is None:
             rel_emb = RelativePositionBiases(num_buckets=32,
                                              max_distance=128,
-                                             num_heads=self.num_heads,
+                                             num_heads=self.num_attention_heads,
                                              dtype=self.dtype,
                                              embedding_init=nn.initializers.variance_scaling(
                                                  1.0, 'fan_avg', 'uniform'),
@@ -807,7 +844,8 @@
         x = inputs

         # [batch, length, emb_dim] -> [batch, length, emb_dim]
-        x = MultiHeadAttention(num_heads=self.num_heads,
+        x = MultiHeadAttention(num_heads=self.num_attention_heads,
+                               num_gqa_groups=self.num_gqa_groups,
                                dtype=self.dtype,
                                head_dim=self.head_dim,
                                transpose_batch_sequence=self.transpose_batch_sequence,
@@ -868,7 +906,8 @@
 class DecoderLayer(nn.Module):
     """Transformer decoder layer that attends to the encoder."""
     relative_embedding: nn.Module = None
-    num_heads: int = 8
+    num_attention_heads: int = 8
+    num_gqa_groups: int | None = None
     head_dim: int = 64
     dropout_rate: float = 0.1
     transpose_batch_sequence: bool = True
@@ -886,6 +925,11 @@ class DecoderLayer(nn.Module):
     fuse_qkv_params: bool = True
     fuse_mlp_wi: bool = False

+    def __post_init__(self):
+        if self.num_gqa_groups is None:
+            self.num_gqa_groups = self.num_attention_heads
+        super().__post_init__()
+
     @nn.compact
     def __call__(self,
                  inputs,
@@ -903,7 +947,7 @@
         if self.relative_embedding is None:
             rel_emb = RelativePositionBiases(num_buckets=32,
                                              max_distance=128,
-                                             num_heads=self.num_heads,
+                                             num_heads=self.num_attention_heads,
                                              dtype=self.dtype,
                                              embedding_init=nn.initializers.variance_scaling(
                                                  1.0, 'fan_avg', 'uniform'),
@@ -928,7 +972,8 @@
         x = inputs

         # Self-attention block
-        x = MultiHeadAttention(num_heads=self.num_heads,
+        x = MultiHeadAttention(num_heads=self.num_attention_heads,
+                               num_gqa_groups=self.num_gqa_groups,
                                dtype=self.dtype,
                                head_dim=self.head_dim,
                                transpose_batch_sequence=self.transpose_batch_sequence,
@@ -960,7 +1005,8 @@
         if self.apply_residual_connection_post_layernorm:
             residual = y

-        y = MultiHeadAttention(num_heads=self.num_heads,
+        y = MultiHeadAttention(num_heads=self.num_attention_heads,
+                               num_gqa_groups=self.num_gqa_groups,
                                dtype=self.dtype,
                                head_dim=self.head_dim,
                                transpose_batch_sequence=self.transpose_batch_sequence,
@@ -1012,6 +1058,9 @@
 def make_causal_mask(batch, seqlen, dtype=jnp.uint8):
+    """
+    Generate causal mask
+    """
     shape = (batch, seqlen)
     idxs = jnp.broadcast_to(jnp.arange(shape[-1], dtype=jnp.int32), shape)
@@ -1022,6 +1071,9 @@ def make_causal_mask(batch, seqlen, dtype=jnp.uint8):
 def make_self_mask(batch, seqlen, dtype=jnp.uint8):
+    """
+    Generate attention mask
+    """
     shape = (batch, seqlen)
     mask = jnp.ones((*shape, shape[-1]))
     mask = jnp.expand_dims(mask, axis=-3)
@@ -1057,7 +1109,7 @@ def assert_allclose(
     dtype = actual.dtype

     # Determine tolerances
-    tols = dict()
+    tols = {}
     if rtol is None or atol is None:
         tols = dtype_tols(dtype)
     if rtol is not None:
...
@@ -573,14 +573,14 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;

-  const DType QKV_type = input_QKV->data.dtype;
+  const auto QKV_type = input_QKV->data.dtype;
   void *devPtrQKV = input_QKV->data.dptr;
   NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
-  auto stride = 0;
+  size_t stride = 0;
   if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
-    stride = 2 * num_attn_heads * head_dim;
+    stride = typeToSize(QKV_type) * num_attn_heads * head_dim;
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
-    stride = 2 * head_dim;
+    stride = typeToSize(QKV_type) * head_dim;
   }
   void *devPtrQ = static_cast<void *>(devPtrQKV);
   void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
@@ -677,14 +677,15 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;

+  const auto QKV_type = input_QKV->data.dtype;
   void *devPtrQKV = input_QKV->data.dptr;
   NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
-  auto stride = 0;
+  size_t stride = 0;
   if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
-    stride = 2 * num_attn_heads * head_dim;
+    stride = typeToSize(QKV_type) * num_attn_heads * head_dim;
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
-    stride = 2 * head_dim;
+    stride = typeToSize(QKV_type) * head_dim;
   }
   void *devPtrQ = devPtrQKV;
   void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
@@ -712,7 +713,6 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
   void* devPtrDropoutOffset = reinterpret_cast<void *>(
       reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);

-  const auto qkv_type = input_QKV->data.dtype;
   size_t workspace_size = 0;

   fused_attn_arbitrary_seqlen_bwd_impl(batch, num_attn_heads, num_attn_heads,
@@ -723,7 +723,7 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
                                        devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
                                        devPtrDropoutSeed, devPtrDropoutOffset,
                                        devPtrCuSeqlens, devPtrCuSeqlens,
-                                       get_cudnn_fe_dtype(qkv_type), workspace->data.dptr,
+                                       get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
                                        &workspace_size, stream, handle);

   if (workspace_size > 0) {
@@ -750,15 +750,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
     const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;

-  const DType QKV_type = input_Q->data.dtype;
+  const auto QKV_type = input_Q->data.dtype;
   void *devPtrQ = input_Q->data.dptr;
   void *devPtrKV = input_KV->data.dptr;
   NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
-  auto stride = 0;
+  size_t stride = 0;
   if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
-    stride = 2 * num_attn_heads * head_dim;
+    stride = typeToSize(QKV_type) * num_gqa_groups * head_dim;
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
-    stride = 2 * head_dim;
+    stride = typeToSize(QKV_type) * head_dim;
   }
   void *devPtrK = devPtrKV;
   void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
@@ -860,15 +860,14 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
   using namespace transformer_engine;

   const auto QKV_type = input_Q->data.dtype;
   void *devPtrQ = input_Q->data.dptr;
   void *devPtrKV = input_KV->data.dptr;
   NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
-  auto stride = 0;
+  size_t stride = 0;
   if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
-    stride = 2 * num_attn_heads * head_dim;
+    stride = typeToSize(QKV_type) * num_gqa_groups * head_dim;
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
-    stride = 2 * head_dim;
+    stride = typeToSize(QKV_type) * head_dim;
   }
   void *devPtrK = devPtrKV;
   void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
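The stride here is a byte offset added to the packed KV base pointer to locate V, so it has to use the element size (typeToSize(QKV_type), replacing the previous hard-coded 2 that only held for fp16/bf16) and, for KV-packed layouts, the number of GQA groups rather than the number of query heads. The same arithmetic in a small Python sketch (names are illustrative):

    def kv_byte_offsets(layout_group, dtype_bytes, num_gqa_groups, head_dim):
        """Return the byte offsets of K and V relative to the packed KV base pointer."""
        if layout_group == 'HD_2HD':      # KV stored as [..., 2, num_gqa_groups, head_dim]
            v_offset = dtype_bytes * num_gqa_groups * head_dim
        elif layout_group == 'HD_H2D':    # KV stored as [..., num_gqa_groups, 2, head_dim]
            v_offset = dtype_bytes * head_dim
        else:
            raise ValueError(layout_group)
        return 0, v_offset                # K starts at the base pointer, V one stride later

    print(kv_byte_offsets('HD_2HD', 2, 4, 64))    # (0, 512) for fp16/bf16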
@@ -935,7 +934,7 @@ void fused_attn_arbitrary_seqlen_fwd(
     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;

-  const DType QKV_type = input_Q->data.dtype;
+  const auto QKV_type = input_Q->data.dtype;
   void *devPtrQ = input_Q->data.dptr;
   void *devPtrK = input_K->data.dptr;
   void *devPtrV = input_V->data.dptr;
...
@@ -1651,13 +1651,10 @@ class FusedAttnHelper:

     def get_fused_attn_backend(self):
         """Get the fused attention kernel backend"""
-        return transformer_engine_jax.get_fused_attn_backend(jax_dtype_to_te_dtype(self.q_type),
-                                                              jax_dtype_to_te_dtype(self.kv_type),
-                                                              self.qkv_layout, self.attn_bias_type,
-                                                              self.attn_mask_type,
-                                                              self.dropout_probability,
-                                                              self.num_heads_q, self.num_heads_kv,
-                                                              self.max_seqlen_q, self.max_seqlen_kv,
-                                                              self.head_dim)
+        return transformer_engine_jax.get_fused_attn_backend(
+            jax_dtype_to_te_dtype(self.q_type), jax_dtype_to_te_dtype(self.kv_type),
+            self.qkv_layout, self.attn_bias_type, self.attn_mask_type, self.dropout_probability,
+            self.num_heads_q, self.num_heads_kv, self.max_seqlen_q, self.max_seqlen_kv,
+            self.head_dim)
@@ -1701,12 +1698,11 @@ class _FusedAttnRNGStateChecker:
         return seed


-def generate_cu_seqlen(mask):
+def generate_cu_seqlen(actual_seqlen):
     """
     Generating cumsum seqlen for a batch
     """
-    seqlen = jnp.sum(mask == 0, axis=(-1, -2), dtype=jnp.int32)
-    cu_seqlen = jnp.cumsum(seqlen)
+    cu_seqlen = jnp.cumsum(actual_seqlen)
    cu_seqlen = jnp.hstack((0, cu_seqlen))
     return cu_seqlen
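The helper now receives per-sequence token counts directly (computed from the mask once, outside the primitive, per the "Calculate seqlen before the primitive" change) and only turns them into the cumulative offsets the fused attention kernels consume. A standalone sketch:

    import jax.numpy as jnp

    def generate_cu_seqlen(actual_seqlen):
        # [s0, s1, s2] -> [0, s0, s0+s1, s0+s1+s2]: cumulative sequence offsets
        cu_seqlen = jnp.cumsum(actual_seqlen)
        return jnp.hstack((0, cu_seqlen))

    print(generate_cu_seqlen(jnp.asarray([3, 5, 2])))    # [ 0  3  8 10]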
@@ -1722,13 +1718,13 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
     outer_primitive = None

     @staticmethod
-    def abstract(qkv_aval, bias_aval, mask_or_cu_seqlen_aval, seed_aval, *, attn_bias_type,
+    def abstract(qkv_aval, bias_aval, seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type,
                  attn_mask_type, scaling_factor, dropout_probability, is_training):
         """
         Self fused attention fwd abstract
         """
-        # outer_primitve is squeezed_mask, inner_primitive is cu_seqlen
-        del mask_or_cu_seqlen_aval, scaling_factor, is_training
+        # outer_primitve is seqlen, inner_primitive is cu_seqlen
+        del seqlen_or_cu_seqlen_aval, scaling_factor, is_training
         qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
         *batch_shape, max_seqlen, nqkv, num_head, head_dim = qkv_aval.shape
         assert nqkv == 3
@@ -1781,19 +1777,20 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
         args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

         opaque = transformer_engine_jax.pack_fused_attn_descriptor(
-            batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability,
-            attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
+            batch, num_head, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor,
+            dropout_probability, attn_bias_type, attn_mask_type,
+            jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)

         out = custom_caller(SelfFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)

         return out

     @staticmethod
-    def impl(qkv, bias, squeezed_mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
+    def impl(qkv, bias, seqlen, seed, attn_bias_type, attn_mask_type, scaling_factor,
              dropout_probability, is_training):
         assert SelfFusedAttnFwdPrimitive.inner_primitive is not None

-        cu_seqlen = generate_cu_seqlen(squeezed_mask)
+        cu_seqlen = generate_cu_seqlen(seqlen)

         output, softmax_aux, rng_state = SelfFusedAttnFwdPrimitive.inner_primitive.bind(
             qkv,
@@ -1859,10 +1856,9 @@
 register_primitive(SelfFusedAttnFwdPrimitive)


-def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, squeezed_mask: jnp.ndarray,
-                        seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
-                        attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
-                        dropout_probability: float, is_training: bool):
+def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, seqlen: jnp.ndarray, seed: jnp.ndarray,
+                        attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
+                        scaling_factor: float, dropout_probability: float, is_training: bool):
     """
     Wrapper for TE self fused attention fwd
     Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
@@ -1875,7 +1871,7 @@
         bias = jnp.zeros(0, dtype=qkv.dtype)
     return SelfFusedAttnFwdPrimitive.outer_primitive.bind(qkv,
                                                           bias,
-                                                          squeezed_mask,
+                                                          seqlen,
                                                           seed,
                                                           attn_bias_type=attn_bias_type,
                                                           attn_mask_type=attn_mask_type,
@@ -1896,14 +1892,14 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):

     @staticmethod
     def abstract(qkv_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval, doutput_aval,
-                 mask_or_cu_seqlen_aval, *, attn_bias_type, attn_mask_type, scaling_factor,
+                 seqlen_or_cu_seqlen_aval, *, attn_bias_type, attn_mask_type, scaling_factor,
                  dropout_probability, is_training):
         """
         Self fused attention bwd abstract
         """
         del softmax_aux_aval, rng_state_aval
-        # outer_primitve is squeezed_mask, inner_primitive is cu_seqlen
-        del mask_or_cu_seqlen_aval, attn_bias_type, attn_mask_type
+        # outer_primitve is seqlen, inner_primitive is cu_seqlen
+        del seqlen_or_cu_seqlen_aval, attn_bias_type, attn_mask_type
         del scaling_factor, dropout_probability, is_training
         qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
         bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
@@ -1934,19 +1930,20 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
         args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

         opaque = transformer_engine_jax.pack_fused_attn_descriptor(
-            batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability,
-            attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
+            batch, num_head, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor,
+            dropout_probability, attn_bias_type, attn_mask_type,
+            jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)

         out = custom_caller(SelfFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)

         return out

     @staticmethod
-    def impl(qkv, bias, softmax_aux, rng_state, output, doutput, squeezed_mask, attn_bias_type,
+    def impl(qkv, bias, softmax_aux, rng_state, output, doutput, seqlen, attn_bias_type,
              attn_mask_type, scaling_factor, dropout_probability, is_training):
         assert SelfFusedAttnBwdPrimitive.inner_primitive is not None

-        cu_seqlen = generate_cu_seqlen(squeezed_mask)
+        cu_seqlen = generate_cu_seqlen(seqlen)

         dqkv, dbias = SelfFusedAttnBwdPrimitive.inner_primitive.bind(
             qkv,
@@ -2029,7 +2026,7 @@ register_primitive(SelfFusedAttnBwdPrimitive)

 def self_fused_attn_bwd(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: jnp.ndarray,
                         rng_state: jnp.ndarray, output: jnp.ndarray, doutput: jnp.ndarray,
-                        squeezed_mask: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
+                        seqlen: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
                         attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
                         dropout_probability: float, is_training: bool):
     """
@@ -2045,7 +2042,7 @@
         rng_state,
         output,
         doutput,
-        squeezed_mask,
+        seqlen,
         attn_bias_type=attn_bias_type,
         attn_mask_type=attn_mask_type,
         scaling_factor=scaling_factor,
@@ -2064,13 +2061,13 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
     outer_primitive = None

     @staticmethod
-    def abstract(q_aval, kv_aval, bias_aval, q_mask_or_cu_seqlen_aval, kv_mask_or_cu_seqlen_aval,
-                 seed_aval, *, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
-                 is_training):
+    def abstract(q_aval, kv_aval, bias_aval, q_seqlen_or_cu_seqlen_aval,
+                 kv_seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, attn_mask_type,
+                 scaling_factor, dropout_probability, is_training):
         """
         Cross fused attention fwd abstract
         """
-        # outer_primitve is squeezed_mask, inner_primitive is cu_seqlen
+        # outer_primitve is seqlen, inner_primitive is cu_seqlen
         del scaling_factor, is_training

         q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
@@ -2083,18 +2080,17 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
         assert q_dtype == kv_dtype == bias_dtype
         assert q_batch_shape == kv_batch_shape
-        assert q_num_head == kv_num_head
         assert q_head_dim == kv_head_dim
         assert nkv == 2
-        assert q_mask_or_cu_seqlen_aval.dtype == kv_mask_or_cu_seqlen_aval.dtype
+        assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype

         output_shape = q_aval.shape
         output_dtype = q_dtype

         backend = FusedAttnHelper(q_dtype, kv_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
-                                  attn_bias_type, attn_mask_type, dropout_probability,
-                                  q_num_head, kv_num_head,
-                                  q_max_seqlen, kv_max_seqlen, q_head_dim).get_fused_attn_backend()
+                                  attn_bias_type, attn_mask_type, dropout_probability, q_num_head,
+                                  kv_num_head, q_max_seqlen, kv_max_seqlen,
+                                  q_head_dim).get_fused_attn_backend()

         if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
             softmax_aux_shape = (*q_batch_shape, q_num_head, q_max_seqlen, kv_max_seqlen)
...@@ -2128,7 +2124,7 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive): ...@@ -2128,7 +2124,7 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
*batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape *batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape
batch = reduce(operator.mul, batch_shape) batch = reduce(operator.mul, batch_shape)
kv_max_seqlen = kv_aval.shape[-4] kv_max_seqlen, kv_num_head = kv_aval.shape[-4], kv_aval.shape[-2]
operands = [q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed] operands = [q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed]
operand_shapes = map(lambda x: x.type.shape, operands) operand_shapes = map(lambda x: x.type.shape, operands)
...@@ -2139,7 +2135,7 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive): ...@@ -2139,7 +2135,7 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_fused_attn_descriptor( opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, batch, num_head, kv_num_head, q_max_seqlen, kv_max_seqlen, head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training) jax_dtype_to_te_dtype(q_aval.dtype), is_training)
...@@ -2148,12 +2144,12 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive): ...@@ -2148,12 +2144,12 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
return out return out
@staticmethod @staticmethod
def impl(q, kv, bias, q_squeezed_mask, kv_squeezed_mask, seed, attn_bias_type, attn_mask_type, def impl(q, kv, bias, q_seqlen, kv_seqlen, seed, attn_bias_type, attn_mask_type, scaling_factor,
scaling_factor, dropout_probability, is_training): dropout_probability, is_training):
assert CrossFusedAttnFwdPrimitive.inner_primitive is not None assert CrossFusedAttnFwdPrimitive.inner_primitive is not None
q_cu_seqlen = generate_cu_seqlen(q_squeezed_mask) q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_squeezed_mask) kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
output, softmax_aux, rng_state = CrossFusedAttnFwdPrimitive.inner_primitive.bind( output, softmax_aux, rng_state = CrossFusedAttnFwdPrimitive.inner_primitive.bind(
q, q,
...@@ -2224,9 +2220,8 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive): ...@@ -2224,9 +2220,8 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
register_primitive(CrossFusedAttnFwdPrimitive) register_primitive(CrossFusedAttnFwdPrimitive)
def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, q_seqlen: jnp.ndarray,
q_squeezed_mask: jnp.ndarray, kv_squeezed_mask: jnp.ndarray, kv_seqlen: jnp.ndarray, seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float, attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool): dropout_probability: float, is_training: bool):
""" """
...@@ -2243,8 +2238,8 @@ def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, ...@@ -2243,8 +2238,8 @@ def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
return CrossFusedAttnFwdPrimitive.outer_primitive.bind(q, return CrossFusedAttnFwdPrimitive.outer_primitive.bind(q,
kv, kv,
bias, bias,
q_squeezed_mask, q_seqlen,
kv_squeezed_mask, kv_seqlen,
seed, seed,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
...@@ -2296,7 +2291,7 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive): ...@@ -2296,7 +2291,7 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
*batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape *batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape
batch = reduce(operator.mul, batch_shape) batch = reduce(operator.mul, batch_shape)
kv_max_seqlen = kv_aval.shape[-4] kv_max_seqlen, kv_num_head = kv_aval.shape[-4], kv_aval.shape[-2]
operands = [q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen] operands = [q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen]
operand_shapes = map(lambda x: x.type.shape, operands) operand_shapes = map(lambda x: x.type.shape, operands)
...@@ -2310,7 +2305,7 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive): ...@@ -2310,7 +2305,7 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
# the dropout elements are encoded in the forward auxiliary tensor # the dropout elements are encoded in the forward auxiliary tensor
# so seed is not needed in backward # so seed is not needed in backward
opaque = transformer_engine_jax.pack_fused_attn_descriptor( opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, batch, num_head, kv_num_head, q_max_seqlen, kv_max_seqlen, head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training) jax_dtype_to_te_dtype(q_aval.dtype), is_training)
...@@ -2319,13 +2314,12 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive): ...@@ -2319,13 +2314,12 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
return out return out
@staticmethod @staticmethod
def impl(q, kv, bias, softmax_aux, rng_state, output, doutput, q_squeezed_mask, def impl(q, kv, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen,
kv_squeezed_mask, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training):
is_training):
assert CrossFusedAttnBwdPrimitive.inner_primitive is not None assert CrossFusedAttnBwdPrimitive.inner_primitive is not None
q_cu_seqlen = generate_cu_seqlen(q_squeezed_mask) q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_squeezed_mask) kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
dq, dkv, dbias = CrossFusedAttnBwdPrimitive.inner_primitive.bind( dq, dkv, dbias = CrossFusedAttnBwdPrimitive.inner_primitive.bind(
q, q,
...@@ -2417,10 +2411,9 @@ register_primitive(CrossFusedAttnBwdPrimitive) ...@@ -2417,10 +2411,9 @@ register_primitive(CrossFusedAttnBwdPrimitive)
def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray, softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray,
doutput: jnp.ndarray, q_squeezed_mask: jnp.ndarray, doutput: jnp.ndarray, q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray,
kv_squeezed_mask: jnp.ndarray, attn_bias_type: NVTE_Bias_Type, attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float, scaling_factor: float, dropout_probability: float, is_training: bool):
dropout_probability: float, is_training: bool):
""" """
Wrapper for TE cross fused attention bwd Wrapper for TE cross fused attention bwd
Return the gradients of cross fused attention with packed kv input Return the gradients of cross fused attention with packed kv input
...@@ -2435,8 +2428,8 @@ def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, ...@@ -2435,8 +2428,8 @@ def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
rng_state, rng_state,
output, output,
doutput, doutput,
q_squeezed_mask, q_seqlen,
kv_squeezed_mask, kv_seqlen,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
......
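A note on the `seqlen` arguments threaded through the wrappers above: the outer primitives take per-batch sequence lengths, and the `impl` methods convert them to the cumulative offsets (`cu_seqlen`) that the fused-attention kernels consume. Below is a minimal sketch of that conversion, using a hypothetical `to_cu_seqlen` helper for illustration rather than the library's actual `generate_cu_seqlen`:

```python
import jax.numpy as jnp

def to_cu_seqlen(seqlen: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: convert per-batch sequence lengths, e.g. [3, 5, 2],
    into cumulative offsets [0, 3, 8, 10] as expected by fused-attention kernels."""
    return jnp.concatenate([jnp.zeros((1,), dtype=jnp.int32),
                            jnp.cumsum(seqlen, dtype=jnp.int32)])

# Example: three sequences of length 3, 5 and 2 in a padded batch
print(to_cu_seqlen(jnp.array([3, 5, 2], dtype=jnp.int32)))  # [ 0  3  8 10]
```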
...@@ -82,12 +82,12 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch, ...@@ -82,12 +82,12 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch,
} }
pybind11::bytes PackCustomCallFusedAttnDescriptor( pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim, size_t batch, size_t num_head, size_t num_gqa_groups, size_t q_max_seqlen, size_t kv_max_seqlen,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, DType dtype, bool is_training) { NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
return PackOpaque(CustomCallFusedAttnDescriptor{batch, num_head, q_max_seqlen, kv_max_seqlen, return PackOpaque(CustomCallFusedAttnDescriptor{
head_dim, scaling_factor, dropout_probability, batch, num_head, num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim, scaling_factor,
bias_type, mask_type, dtype, is_training}); dropout_probability, bias_type, mask_type, dtype, is_training});
} }
void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream, void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream,
...@@ -745,8 +745,8 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, ...@@ -745,8 +745,8 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
size_t head_dim) { size_t head_dim) {
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_num_heads, kv_num_heads, mask_type, dropout_probability, q_num_heads, kv_num_heads, q_max_seqlen, kv_max_seqlen,
q_max_seqlen, kv_max_seqlen, head_dim); head_dim);
return backend; return backend;
} }
...@@ -768,6 +768,7 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu ...@@ -768,6 +768,7 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu
auto batch = descriptor.batch; auto batch = descriptor.batch;
auto num_head = descriptor.num_head; auto num_head = descriptor.num_head;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto q_max_seqlen = descriptor.q_max_seqlen; auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen; auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim; auto head_dim = descriptor.head_dim;
...@@ -779,6 +780,9 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu ...@@ -779,6 +780,9 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu
NVTE_CHECK(q_max_seqlen == kv_max_seqlen, NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
"q_max_seqlen should be equal to kv_max_seqlen in the self attention."); "q_max_seqlen should be equal to kv_max_seqlen in the self attention.");
NVTE_CHECK(num_head == num_gqa_groups,
"num_head should be equal to num_gqa_groups in the qkvpacked attention");
auto dtype = descriptor.dtype; auto dtype = descriptor.dtype;
auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim}; auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen}; auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
...@@ -799,10 +803,10 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu ...@@ -799,10 +803,10 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu
// aux tensors // aux tensors
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64); auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend( auto backend =
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type, nvte_get_fused_attn_backend(static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype),
mask_type, dropout_probability, num_head, num_head, qkv_layout, bias_type, mask_type, dropout_probability, num_head,
q_max_seqlen, kv_max_seqlen, head_dim); num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
NVTETensorPack aux_output_tensors; NVTETensorPack aux_output_tensors;
...@@ -853,6 +857,7 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq ...@@ -853,6 +857,7 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq
auto batch = descriptor.batch; auto batch = descriptor.batch;
auto num_head = descriptor.num_head; auto num_head = descriptor.num_head;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto q_max_seqlen = descriptor.q_max_seqlen; auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen; auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim; auto head_dim = descriptor.head_dim;
...@@ -864,6 +869,9 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq ...@@ -864,6 +869,9 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq
NVTE_CHECK(q_max_seqlen == kv_max_seqlen, NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
"q_max_seqlen should be equal to kv_max_seqlen in the self attention."); "q_max_seqlen should be equal to kv_max_seqlen in the self attention.");
NVTE_CHECK(num_head == num_gqa_groups,
"num_head should be equal to num_gqa_groups in the qkvpacked attention");
auto dtype = descriptor.dtype; auto dtype = descriptor.dtype;
auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim}; auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim};
auto output_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}; auto output_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
...@@ -941,6 +949,7 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq ...@@ -941,6 +949,7 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
auto batch = descriptor.batch; auto batch = descriptor.batch;
auto num_head = descriptor.num_head; auto num_head = descriptor.num_head;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto q_max_seqlen = descriptor.q_max_seqlen; auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen; auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim; auto head_dim = descriptor.head_dim;
...@@ -951,7 +960,7 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq ...@@ -951,7 +960,7 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
auto dtype = descriptor.dtype; auto dtype = descriptor.dtype;
auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}; auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_head, head_dim}; auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen}; auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
// input tensors // input tensors
...@@ -976,10 +985,10 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq ...@@ -976,10 +985,10 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64); auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend( auto backend =
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type, nvte_get_fused_attn_backend(static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype),
mask_type, dropout_probability, num_head, num_head, qkv_layout, bias_type, mask_type, dropout_probability, num_head,
q_max_seqlen, kv_max_seqlen, head_dim); num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
NVTETensorPack aux_output_tensors; NVTETensorPack aux_output_tensors;
...@@ -1035,6 +1044,7 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa ...@@ -1035,6 +1044,7 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
auto batch = descriptor.batch; auto batch = descriptor.batch;
auto num_head = descriptor.num_head; auto num_head = descriptor.num_head;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto q_max_seqlen = descriptor.q_max_seqlen; auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen; auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim; auto head_dim = descriptor.head_dim;
...@@ -1045,7 +1055,7 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa ...@@ -1045,7 +1055,7 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
auto dtype = descriptor.dtype; auto dtype = descriptor.dtype;
auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}; auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_head, head_dim}; auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto output_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}; auto output_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen}; auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
......
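In the C++ custom calls above, `num_gqa_groups` only changes the packed KV shape: the query keeps `num_head` heads while K and V carry `num_gqa_groups` heads, and the qkv-packed self-attention path still requires the two counts to match because a single `BS3HD` tensor cannot hold differently-sized Q and KV. A small shape sketch under assumed dimensions:

```python
import jax.numpy as jnp

batch, q_len, kv_len, head_dim = 2, 128, 128, 64
num_heads, num_gqa_groups = 16, 4   # 4 query heads share each KV head

q  = jnp.zeros((batch, q_len, num_heads, head_dim))            # BSHD query
kv = jnp.zeros((batch, kv_len, 2, num_gqa_groups, head_dim))   # BS2HD packed K/V

assert num_heads % num_gqa_groups == 0   # query heads must divide evenly into KV groups
```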
...@@ -98,6 +98,7 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch, ...@@ -98,6 +98,7 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch,
struct CustomCallFusedAttnDescriptor { struct CustomCallFusedAttnDescriptor {
size_t batch; size_t batch;
size_t num_head; size_t num_head;
size_t num_gqa_groups;
size_t q_max_seqlen; size_t q_max_seqlen;
size_t kv_max_seqlen; size_t kv_max_seqlen;
size_t head_dim; size_t head_dim;
...@@ -110,8 +111,8 @@ struct CustomCallFusedAttnDescriptor { ...@@ -110,8 +111,8 @@ struct CustomCallFusedAttnDescriptor {
}; };
pybind11::bytes PackCustomCallFusedAttnDescriptor( pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim, size_t batch, size_t num_head, size_t num_gqa_groups, size_t q_max_seqlen, size_t kv_max_seqlen,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, DType dtype, bool is_training); NVTE_Mask_Type mask_type, DType dtype, bool is_training);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
......
...@@ -16,7 +16,6 @@ import jax.numpy as jnp ...@@ -16,7 +16,6 @@ import jax.numpy as jnp
import numpy as np import numpy as np
from flax import linen as nn from flax import linen as nn
from flax.linen import partitioning as nn_partitioning from flax.linen import partitioning as nn_partitioning
from jax import dtypes
from jax import nn as jax_nn from jax import nn as jax_nn
from jax import random as jax_random from jax import random as jax_random
from jax import lax, vmap from jax import lax, vmap
...@@ -198,22 +197,31 @@ def core_attention(query: Array, ...@@ -198,22 +197,31 @@ def core_attention(query: Array,
batch_dim = 1 if transpose_batch_sequence else 0 batch_dim = 1 if transpose_batch_sequence else 0
assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], ( assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], (
'q, k, v batch dims must match.') 'q, k, v batch dims must match.')
assert query.shape[-2] == key.shape[-2] == value.shape[-2], ('q, k, v num_heads must match.')
sequence_dim = 0 if transpose_batch_sequence else 1 sequence_dim = 0 if transpose_batch_sequence else 1
assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.' assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.'
assert query.shape[-1] == key.shape[-1], 'q, k depths must match.' assert key.shape[-2] == value.shape[-2], 'k, v num_heads must match.'
assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.'
if float32_logits: if float32_logits:
query = query.astype(jnp.float32) query = query.astype(jnp.float32)
key = key.astype(jnp.float32) key = key.astype(jnp.float32)
h_q, h_kv = query.shape[-2], key.shape[-2]
assert (h_q % h_kv == 0) and (h_q >= h_kv)
group_size = h_q // h_kv
grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))
if transpose_batch_sequence: if transpose_batch_sequence:
attn_weights = jnp.einsum('qbhd,kbhd->bhqk', query, key) attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key)
else: else:
attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key) attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)
attn_weights = checkpoint_name(attn_weights, 'logits') attn_weights = checkpoint_name(attn_weights, 'logits')
b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
attn_weights_without_groups_shape = (b, h * g, q, k)
attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)
attn_weights = _with_sharding_constraint(attn_weights, attn_weights = _with_sharding_constraint(attn_weights,
(BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES)) (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES))
...@@ -229,6 +237,8 @@ def core_attention(query: Array, ...@@ -229,6 +237,8 @@ def core_attention(query: Array,
attn_weights = Softmax(softmax_type=softmax_type, attn_weights = Softmax(softmax_type=softmax_type,
scale_factor=fused_scale_factor)(attn_weights, mask, bias).astype(dtype) scale_factor=fused_scale_factor)(attn_weights, mask, bias).astype(dtype)
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
if not deterministic and dropout_rate > 0.: if not deterministic and dropout_rate > 0.:
keep_prob = 1.0 - dropout_rate keep_prob = 1.0 - dropout_rate
dropout_shape = list(attn_weights.shape) dropout_shape = list(attn_weights.shape)
...@@ -238,9 +248,9 @@ def core_attention(query: Array, ...@@ -238,9 +248,9 @@ def core_attention(query: Array,
attn_weights = attn_weights * multiplier attn_weights = attn_weights * multiplier
if transpose_batch_sequence: if transpose_batch_sequence:
return jnp.einsum('bhqk,kbhd->qbhd', attn_weights, value) return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape)
return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value) return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape)
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
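The unfused `core_attention` changes above implement GQA by folding the query heads into `h_q // h_kv` groups before the logits einsum and unfolding them again after the value einsum. A self-contained sketch of the same idea (batch-first layout, softmax only, no mask, bias, or dropout):

```python
import jax
import jax.numpy as jnp

def grouped_query_attention(query, key, value):
    """Minimal GQA sketch: query [b, sq, hq, d], key/value [b, skv, hkv, d],
    with hq a multiple of hkv."""
    b, sq, hq, d = query.shape
    hkv = key.shape[-2]
    assert hq % hkv == 0 and hq >= hkv
    group_size = hq // hkv

    # Fold query heads into (hkv, group_size) so each group shares one KV head.
    grouped_q = query.reshape(b, sq, hkv, group_size, d)
    logits = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_q, key) / jnp.sqrt(d)
    weights = jax.nn.softmax(logits, axis=-1)
    context = jnp.einsum('bhgqk,bkhd->bqhgd', weights, value)
    return context.reshape(query.shape)

q  = jnp.ones((2, 8, 16, 64))
kv = jnp.ones((2, 8, 4, 64))
print(grouped_query_attention(q, kv, kv).shape)   # (2, 8, 16, 64)
```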
...@@ -262,6 +272,14 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -262,6 +272,14 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
The hidden dimension of each attention head. The hidden dimension of each attention head.
num_heads : int num_heads : int
The number of attention heads The number of attention heads
num_gqa_groups : int, default = `None`
Number of GQA groups. If `None`, it defaults to `num_heads`.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the queries.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
dropout_rate : float, default = 0.0 dropout_rate : float, default = 0.0
Dropout probability for the dropout op during multi-head attention. Dropout probability for the dropout op during multi-head attention.
dropout_rng_name: str, default = 'dropout' dropout_rng_name: str, default = 'dropout'
...@@ -321,6 +339,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -321,6 +339,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
head_dim: int head_dim: int
num_heads: int num_heads: int
num_gqa_groups: int | None = None
dropout_rate: float = 0. dropout_rate: float = 0.
dropout_rng_name: str = 'dropout' dropout_rng_name: str = 'dropout'
layernorm_type: str = "layernorm" layernorm_type: str = "layernorm"
...@@ -342,6 +361,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -342,6 +361,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal') self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_heads
super().__post_init__() super().__post_init__()
@nn.compact @nn.compact
...@@ -428,30 +449,22 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -428,30 +449,22 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
"supported attn_mask_type = {'causal', 'padding'}") "supported attn_mask_type = {'causal', 'padding'}")
is_self_attn = (inputs_q is inputs_kv) is_self_attn = (inputs_q is inputs_kv)
is_gqa = (self.num_heads != self.num_gqa_groups)
is_qkvpack = (is_self_attn and not is_gqa)
qkv_layout = QKVLayout.BS3HD if is_self_attn else QKVLayout.BSHD_BS2HD qkv_layout = QKVLayout.BS3HD if is_self_attn else QKVLayout.BSHD_BS2HD
attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type) attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
canonicalize_dtype = dtypes.canonicalize_dtype(self.dtype)
q_seqlen = inputs_q.shape[0] if self.transpose_batch_sequence else inputs_q.shape[1] q_seqlen = inputs_q.shape[0] if self.transpose_batch_sequence else inputs_q.shape[1]
kv_seqlen = inputs_kv.shape[0] if self.transpose_batch_sequence else inputs_kv.shape[1] kv_seqlen = inputs_kv.shape[0] if self.transpose_batch_sequence else inputs_kv.shape[1]
enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0")) enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))
def _check_seqlen(seqlen):
return seqlen % 64 == 0
def _check_head_dim(head_dim):
return head_dim in [64, 128]
has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype, qkv_layout, has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype, qkv_layout,
attn_bias_type, attn_mask_type, attn_bias_type, attn_mask_type,
self.dropout_rate, self.dropout_rate, self.num_heads,
self.num_heads, self.num_heads, self.num_gqa_groups, q_seqlen,
q_seqlen, kv_seqlen, self.head_dim) kv_seqlen, self.head_dim)
use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \ use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \
canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \
_check_seqlen(q_seqlen) and _check_seqlen(kv_seqlen) and \
_check_head_dim(self.head_dim) and \
has_fused_attn_kernel and \ has_fused_attn_kernel and \
enable_fused_attn enable_fused_attn
...@@ -464,17 +477,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -464,17 +477,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
f"but got {self.transpose_batch_sequence}, " f"but got {self.transpose_batch_sequence}, "
if not self.fuse_qkv: if not self.fuse_qkv:
reason += f"fuse_qkv=True is required but got {self.fuse_qkv}, " reason += f"fuse_qkv=True is required but got {self.fuse_qkv}, "
if canonicalize_dtype not in [jnp.bfloat16, jnp.float16]:
reason += f"dtype in [BF16, FP16] is required " \
f"but got dtype={canonicalize_dtype}, "
if not _check_seqlen(q_seqlen):
reason += f"q_seqlen % 64 == 0 is required " \
f"but got {q_seqlen=}, "
if not _check_seqlen(kv_seqlen):
reason += f"kv_seqlen % 64 == 0 is required " \
f"but got {kv_seqlen=}, "
if not _check_head_dim(self.head_dim):
reason += f"head_dim should be 64 or 128 but got {self.head_dim}, "
if not has_fused_attn_kernel: if not has_fused_attn_kernel:
reason += "no fused attention kernel is available, " reason += "no fused attention kernel is available, "
...@@ -484,7 +486,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -484,7 +486,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
residual = inputs_q residual = inputs_q
if self.fuse_qkv: if self.fuse_qkv:
if is_self_attn: if is_qkvpack:
qkv_proj, ln_out = LayerNormDenseGeneral( qkv_proj, ln_out = LayerNormDenseGeneral(
enable_layernorm=not self.output_layernorm, enable_layernorm=not self.output_layernorm,
layernorm_type=self.layernorm_type, layernorm_type=self.layernorm_type,
...@@ -515,7 +517,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -515,7 +517,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
axis=-1, axis=-1,
features=self.num_heads * self.head_dim, features=self.num_heads * self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm, return_layernorm_output=(self.apply_residual_connection_post_layernorm
or is_self_attn),
scale_axes=(W_NO_SHARD_AXES,), scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes=(W_FSDP_AXES, W_TP_AXES), kernel_axes=(W_FSDP_AXES, W_TP_AXES),
...@@ -525,8 +528,13 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -525,8 +528,13 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
dtype=self.dtype, dtype=self.dtype,
kernel_init=query_init, kernel_init=query_init,
name='query')(inputs_q) name='query')(inputs_q)
if is_self_attn:
assert ln_out is not None
inputs_kv = ln_out
kv_proj = DenseGeneral(axis=-1, kv_proj = DenseGeneral(axis=-1,
features=(2, self.num_heads * self.head_dim), features=(2, self.num_gqa_groups * self.head_dim),
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_init=kv_init, kernel_init=kv_init,
...@@ -542,7 +550,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -542,7 +550,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
kv_projection = functools.partial( kv_projection = functools.partial(
DenseGeneral, DenseGeneral,
axis=-1, axis=-1,
features=self.num_heads * self.head_dim, features=self.num_gqa_groups * self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=(W_FSDP_AXES, W_TP_AXES), kernel_axes=(W_FSDP_AXES, W_TP_AXES),
use_bias=self.use_bias, use_bias=self.use_bias,
...@@ -583,9 +591,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -583,9 +591,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
query = checkpoint_name(query, 'query_proj') query = checkpoint_name(query, 'query_proj')
key = checkpoint_name(key, 'key_proj') key = checkpoint_name(key, 'key_proj')
value = checkpoint_name(value, 'value_proj') value = checkpoint_name(value, 'value_proj')
query = query.reshape((query.shape[0], query.shape[1], self.num_heads, self.head_dim)) query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
key = key.reshape((key.shape[0], key.shape[1], self.num_heads, self.head_dim)) key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
value = value.reshape((value.shape[0], value.shape[1], self.num_heads, self.head_dim)) value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
qkv_sharding_constraint = \ qkv_sharding_constraint = \
(SEQLEN_AXES, BATCH_AXES, HEAD_AXES, HIDDEN_AXES) \ (SEQLEN_AXES, BATCH_AXES, HEAD_AXES, HIDDEN_AXES) \
if self.transpose_batch_sequence \ if self.transpose_batch_sequence \
...@@ -650,7 +658,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -650,7 +658,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
# ensure the old key never used # ensure the old key never used
del dropout_rng del dropout_rng
if is_self_attn: if is_qkvpack:
qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim)) qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim))
qkv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES, qkv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
HIDDEN_AXES) HIDDEN_AXES)
...@@ -667,7 +675,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -667,7 +675,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
else: else:
assert bias is None assert bias is None
query = query.reshape((*query.shape[:-1], self.num_heads, self.head_dim)) query = query.reshape((*query.shape[:-1], self.num_heads, self.head_dim))
kv_proj = kv_proj.reshape((*kv_proj.shape[:-1], self.num_heads, self.head_dim)) kv_proj = kv_proj.reshape((*kv_proj.shape[:-1], self.num_gqa_groups, self.head_dim))
q_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES) q_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES)
kv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES, kv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
HIDDEN_AXES) HIDDEN_AXES)
...@@ -865,6 +873,14 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -865,6 +873,14 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
Intermediate size to which input samples are projected. Intermediate size to which input samples are projected.
num_attention_heads: int, default = 8 num_attention_heads: int, default = 8
Number of attention heads in the transformer layer. Number of attention heads in the transformer layer.
num_gqa_groups : int, default = `None`
Number of GQA groups. If `None`, it defaults to `num_attention_heads`.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the queries.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm' layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
Indicate the type of layer normalization. Indicate the type of layer normalization.
layernorm_epsilon: float, default = 1e-6 layernorm_epsilon: float, default = 1e-6
...@@ -961,6 +977,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -961,6 +977,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
hidden_size: int = 512 hidden_size: int = 512
mlp_hidden_size: int = 2048 mlp_hidden_size: int = 2048
num_attention_heads: int = 8 num_attention_heads: int = 8
num_gqa_groups: int | None = None
layernorm_type: str = 'layernorm' layernorm_type: str = 'layernorm'
layernorm_epsilon: float = 1e-6 layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False zero_centered_gamma: bool = False
...@@ -995,6 +1012,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -995,6 +1012,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
if self.mlp_kernel_init is None: if self.mlp_kernel_init is None:
self.mlp_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', self.mlp_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in',
'truncated_normal') 'truncated_normal')
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__() super().__post_init__()
@nn.compact @nn.compact
...@@ -1091,6 +1110,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1091,6 +1110,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
num_heads=self.num_attention_heads, num_heads=self.num_attention_heads,
dtype=self.dtype, dtype=self.dtype,
head_dim=head_dim, head_dim=head_dim,
num_gqa_groups=self.num_gqa_groups,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
dropout_rate=self.attention_dropout, dropout_rate=self.attention_dropout,
dropout_rng_name=self.dropout_rng_name, dropout_rng_name=self.dropout_rng_name,
...@@ -1141,6 +1161,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1141,6 +1161,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
num_heads=self.num_attention_heads, num_heads=self.num_attention_heads,
dtype=self.dtype, dtype=self.dtype,
head_dim=head_dim, head_dim=head_dim,
num_gqa_groups=self.num_gqa_groups,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
dropout_rate=self.attention_dropout, dropout_rate=self.attention_dropout,
dropout_rng_name=self.dropout_rng_name, dropout_rng_name=self.dropout_rng_name,
......
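With the flax changes above, GQA is enabled through a single field on the modules; when it is left as `None`, `__post_init__` falls back to the number of attention heads, so existing MHA configurations are unchanged. A usage sketch, with hyperparameters chosen only for illustration:

```python
from transformer_engine.jax.flax import TransformerLayer

# MHA baseline: num_gqa_groups defaults to num_attention_heads.
mha_layer = TransformerLayer(hidden_size=512,
                             mlp_hidden_size=2048,
                             num_attention_heads=16)

# GQA: 16 query heads share 4 KV heads; num_gqa_groups=1 would give MQA.
gqa_layer = TransformerLayer(hidden_size=512,
                             mlp_hidden_size=2048,
                             num_attention_heads=16,
                             num_gqa_groups=4)
```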
...@@ -40,8 +40,8 @@ class QKVLayout(Enum): ...@@ -40,8 +40,8 @@ class QKVLayout(Enum):
def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, attn_mask_type, def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, attn_mask_type,
dropout_probability, num_heads_q, num_heads_kv, dropout_probability, num_heads_q, num_heads_kv, max_seqlen_q,
max_seqlen_q, max_seqlen_kv, head_dim): max_seqlen_kv, head_dim):
""" """
To check whether the fused attention kernel is available To check whether the fused attention kernel is available
""" """
...@@ -83,10 +83,11 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.nda ...@@ -83,10 +83,11 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.nda
seed: jnp.ndarray, attn_bias_type: AttnBiasType, seed: jnp.ndarray, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float, attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool): dropout_probability: float, is_training: bool):
squeezed_mask = mask[..., 0] mask = jnp.logical_not(mask)
actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
output, softmax_aux, rng_state = self_fused_attn_fwd(qkv, output, softmax_aux, rng_state = self_fused_attn_fwd(qkv,
bias, bias,
squeezed_mask, actual_seqlen,
seed, seed,
attn_bias_type=attn_bias_type.value, attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value, attn_mask_type=attn_mask_type.value,
...@@ -96,12 +97,12 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.nda ...@@ -96,12 +97,12 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.nda
output = checkpoint_name(output, 'context') output = checkpoint_name(output, 'context')
softmax_aux = checkpoint_name(softmax_aux, 'context') softmax_aux = checkpoint_name(softmax_aux, 'context')
rng_state = checkpoint_name(rng_state, 'context') rng_state = checkpoint_name(rng_state, 'context')
return output, (qkv, bias, softmax_aux, rng_state, output, squeezed_mask) return output, (qkv, bias, softmax_aux, rng_state, output, actual_seqlen)
def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, dz): is_training, ctx, dz):
qkv, bias, softmax_aux, rng_state, output, squeezed_mask = ctx qkv, bias, softmax_aux, rng_state, output, actual_seqlen = ctx
grad_qkv, grad_bias = self_fused_attn_bwd(qkv, grad_qkv, grad_bias = self_fused_attn_bwd(qkv,
bias, bias,
...@@ -109,7 +110,7 @@ def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dr ...@@ -109,7 +110,7 @@ def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dr
rng_state, rng_state,
output, output,
dz, dz,
squeezed_mask, actual_seqlen,
attn_bias_type=attn_bias_type.value, attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value, attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
...@@ -159,14 +160,19 @@ def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: ...@@ -159,14 +160,19 @@ def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask:
def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type, def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training): scaling_factor, dropout_probability, is_training):
q_squeezed_mask = mask[..., 0] mask = jnp.logical_not(mask)
kv_squeezed_mask = mask[..., 0, :] q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
else:
# When mask is padding + causal, the actual seqlen is not the last row, use max to find it
kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))
output, softmax_aux, rng_state = cross_fused_attn_fwd(q, output, softmax_aux, rng_state = cross_fused_attn_fwd(q,
kv, kv,
bias, bias,
q_squeezed_mask, q_actual_seqlen,
kv_squeezed_mask, kv_actual_seqlen,
seed, seed,
attn_bias_type=attn_bias_type.value, attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value, attn_mask_type=attn_mask_type.value,
...@@ -174,12 +180,12 @@ def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mas ...@@ -174,12 +180,12 @@ def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mas
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training) is_training=is_training)
return output, (q, kv, bias, softmax_aux, rng_state, output, q_squeezed_mask, kv_squeezed_mask) return output, (q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen)
def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, dz): is_training, ctx, dz):
q, kv, bias, softmax_aux, rng_state, output, q_squeezed_mask, kv_squeezed_mask = ctx q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx
grad_q, grad_kv, grad_bias = cross_fused_attn_bwd(q, grad_q, grad_kv, grad_bias = cross_fused_attn_bwd(q,
kv, kv,
...@@ -188,8 +194,8 @@ def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, d ...@@ -188,8 +194,8 @@ def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, d
rng_state, rng_state,
output, output,
dz, dz,
q_squeezed_mask, q_actual_seqlen,
kv_squeezed_mask, kv_actual_seqlen,
attn_bias_type=attn_bias_type.value, attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value, attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
......
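A note on the mask handling above: the custom-VJP rules now reduce the boolean attention mask (where 1 marks a masked-out position) to per-batch sequence lengths before calling the primitives, and for causal (+padding) masks the KV length is taken as the maximum row sum, since every row of a causal mask only covers a prefix. A standalone sketch under the same assumed `[b, 1, sq, skv]` mask layout, with a hypothetical `causal` flag standing in for the mask-type check:

```python
import jax.numpy as jnp

def actual_seqlens(mask, causal=False):
    """Hypothetical helper. mask: [b, 1, sq, skv] with 1 = masked out, 0 = attend.
    Returns per-batch (q_seqlen, kv_seqlen), each of shape (b,)."""
    keep = jnp.logical_not(mask)
    q_seqlen = jnp.sum(keep, axis=-2, dtype=jnp.int32)[..., 0, 0]
    if causal:
        # Each row of a causal mask only attends to a prefix, so the true KV
        # length is the largest row sum rather than the sum of any fixed row.
        kv_seqlen = jnp.max(jnp.sum(keep, axis=-1, dtype=jnp.int32), axis=(-1, -2))
    else:
        kv_seqlen = jnp.sum(keep, axis=-1, dtype=jnp.int32)[..., 0, 0]
    return q_seqlen, kv_seqlen
```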
...@@ -64,6 +64,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer): ...@@ -64,6 +64,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
head_dim: int = 64 head_dim: int = 64
num_heads: int = 16 num_heads: int = 16
num_gqa_groups: int | None = None
dropout_rate: float = 0. dropout_rate: float = 0.
dropout_rng_name: str = 'dropout' dropout_rng_name: str = 'dropout'
layernorm_type: str = "layernorm" layernorm_type: str = "layernorm"
...@@ -80,6 +81,11 @@ class MultiHeadAttention(TransformerEngineBaseLayer): ...@@ -80,6 +81,11 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
scaled_query_init: bool = True scaled_query_init: bool = True
float32_logits: bool = False float32_logits: bool = False
def __post_init__(self):
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_heads
super().__post_init__()
def setup(self) -> None: def setup(self) -> None:
"""setup""" """setup"""
super().setup() super().setup()
...@@ -89,6 +95,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer): ...@@ -89,6 +95,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
dtype=self.dtype, dtype=self.dtype,
head_dim=self.head_dim, head_dim=self.head_dim,
num_heads=self.num_heads, num_heads=self.num_heads,
num_gqa_groups=self.num_gqa_groups,
dropout_rate=self.dropout_rate, dropout_rate=self.dropout_rate,
dropout_rng_name=self.dropout_rng_name, dropout_rng_name=self.dropout_rng_name,
layernorm_type=self.layernorm_type, layernorm_type=self.layernorm_type,
...@@ -131,6 +138,7 @@ class TransformerLayer(TransformerEngineBaseLayer): ...@@ -131,6 +138,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
hidden_size: int = 512 hidden_size: int = 512
mlp_hidden_size: int = 2048 mlp_hidden_size: int = 2048
num_attention_heads: int = 8 num_attention_heads: int = 8
num_gqa_groups: int | None = None
layernorm_type: str = 'layernorm' layernorm_type: str = 'layernorm'
layernorm_epsilon: float = 1e-6 layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False zero_centered_gamma: bool = False
...@@ -156,6 +164,11 @@ class TransformerLayer(TransformerEngineBaseLayer): ...@@ -156,6 +164,11 @@ class TransformerLayer(TransformerEngineBaseLayer):
scale_attn_logits: bool = False scale_attn_logits: bool = False
scaled_query_init: bool = True scaled_query_init: bool = True
def __post_init__(self):
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
def setup(self) -> None: def setup(self) -> None:
"""setup""" """setup"""
super().setup() super().setup()
...@@ -186,6 +199,7 @@ class TransformerLayer(TransformerEngineBaseLayer): ...@@ -186,6 +199,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
mlp_hidden_size=self.mlp_hidden_size, mlp_hidden_size=self.mlp_hidden_size,
num_attention_heads=self.num_attention_heads, num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
layernorm_type=self.layernorm_type, layernorm_type=self.layernorm_type,
layernorm_epsilon=self.layernorm_epsilon, layernorm_epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
......