Unverified commit 8f6c5248 authored by zlsh80826, committed by GitHub
Browse files

[JAX][Common] Support GQA (#578)



* Support num_gqa_groups arguments
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add GQA support on the JAX bridge code
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix the kv stride of the arbitrary backend
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Complete rewrite fused attention tests and add GQA coverage
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Support unfused GQA
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Calculate seqlen before the primitive for better performance
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add GQA layer tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Apply code style checks for te_jax
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Apply code style checks for tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add num_gqa_groups doc
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refine the qkv_type
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Correct the variable naming
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Handle Max512 CAUSAL
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add WAR for the latest jax image
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent daad219f
......@@ -4,6 +4,9 @@
set -xe
# WAR(rewang) for the "Check failed: reduction_kind.has_value()"
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_xla_runtime_executable=true"
: ${TE_PATH:=/opt/transformerengine}
pytest -Wignore -v $TE_PATH/tests/jax/test_distributed_*
......@@ -14,5 +14,7 @@ pytest -Wignore -v $TE_PATH/examples/jax/mnist
# Make encoder tests to have run-to-run deterministic to have the stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
# WAR(rewang) for the "Check failed: reduction_kind.has_value()"
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_xla_runtime_executable=true"
pytest -Wignore -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
pytest -Wignore -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
This diff is collapsed.
......@@ -9,13 +9,14 @@ import jax
import jax.numpy as jnp
import pytest
from transformer_engine.common.recipe import Format
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
from utils import assert_allclose
from utils import DecoderLayer as RefDecoderLayer
from utils import EncoderLayer as RefEncoderLayer
from transformer_engine.common.recipe import Format
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
is_fp8_supported, reason = is_fp8_available()
......@@ -85,8 +86,13 @@ _KEY_OF_LAYERNORM_TYPE = 'layernorm_type'
_KEY_OF_ZERO_CENTERED_GAMMA = 'zero_centered_gamma'
_KEY_OF_TRANSPOSE_BS = 'transpose_batch_sequence'
_KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits"
_KEY_OF_NUM_HEADS = 'num_attention_heads'
_KEY_OF_NUM_GQA_GROUPS = 'num_gqa_groups'
BASE_ATTRS = {_KEY_OF_TRANSPOSE_BS: True}
BASE_ATTRS = {
_KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_NUM_HEADS: 8,
}
ATTRS = [{
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
......@@ -129,6 +135,9 @@ ATTRS = [{
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_MLP_ACTIVATIONS: (('gelu', 'linear')),
_KEY_OF_FUSE_MLP_WI: True
}, {
_KEY_OF_NUM_HEADS: 8,
_KEY_OF_NUM_GQA_GROUPS: 4
}]
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
......@@ -137,21 +146,13 @@ ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
class TestEncoderLayer:
@staticmethod
def sync_params(ref, target, attrs):
fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)
def sync_params(ref, target):
unfreeze_target = flax.core.unfreeze(target)
if fuse_qkv:
unfreeze_target['attention']['qkv']['kernel'] = \
jnp.reshape(ref['attention']['qkv']['kernel'],
unfreeze_target['attention']['qkv']['kernel'].shape)
else:
unfreeze_target['attention']['query']['kernel'] = \
ref['attention']['query']['kernel']
unfreeze_target['attention']['key']['kernel'] = \
ref['attention']['key']['kernel']
unfreeze_target['attention']['value']['kernel'] = \
ref['attention']['value']['kernel']
unfreeze_attn_scope = unfreeze_target['attention']
ref_attn_scope = ref['attention']
for key in ref_attn_scope.keys():
unfreeze_attn_scope[key]['kernel'] = \
ref_attn_scope[key]['kernel'].reshape(unfreeze_attn_scope[key]['kernel'].shape)
unfreeze_target['mlp']['wi_kernel'] = \
jnp.reshape(ref['mlp']['wi']['kernel'], unfreeze_target['mlp']['wi_kernel'].shape)
unfreeze_target['mlp']['wo_kernel'] = \
......@@ -196,7 +197,7 @@ class TestEncoderLayer:
test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
test_masks)
ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params, attrs)
ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params)
ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng)
test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng)
......@@ -242,7 +243,7 @@ class TestEncoderLayer:
test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
test_masks)
ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params, attrs)
ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params)
if FP8Helper.is_fp8_enabled():
for _ in range(4):
......@@ -266,7 +267,10 @@ class TestEncoderLayer:
assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad
def reorganize_test_wgrad(test_wgrad, attrs):
fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)
num_heads = attrs.get(_KEY_OF_NUM_HEADS)
num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads)
fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True) and \
num_heads == num_gqa_groups
attn_name = 'attention'
unfreeze_test_wgrad = flax.core.unfreeze(test_wgrad)
......@@ -280,10 +284,12 @@ class TestEncoderLayer:
unfreeze_test_wgrad['pre_attention_layer_norm']['ln_bias'] = \
unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
if fuse_qkv:
unfreeze_test_wgrad[attn_name]['qkv']['kernel'] = \
jnp.reshape(unfreeze_test_wgrad[attn_name]['qkv']['kernel'],
(unfreeze_test_wgrad[attn_name]['qkv']['kernel'].shape[0], -1))
for key in unfreeze_test_wgrad[attn_name].keys():
unfreeze_test_wgrad[attn_name][key]['kernel'] = \
jnp.reshape(unfreeze_test_wgrad[attn_name][key]['kernel'],
(unfreeze_test_wgrad[attn_name][key]['kernel'].shape[0], -1))
unfreeze_test_wgrad['pre_mlp_layer_norm'] = {}
unfreeze_test_wgrad['pre_mlp_layer_norm']['scale'] = \
unfreeze_test_wgrad['mlp']['scale']
......@@ -348,26 +354,14 @@ class TestEncoderLayer:
class TestDecoderLayer:
@staticmethod
def sync_params(ref, target, attrs):
fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)
def sync_params(ref, target):
unfreeze_target = flax.core.unfreeze(target)
if fuse_qkv:
unfreeze_target['self_attention']['qkv']['kernel'] = \
jnp.reshape(ref['self_attention']['qkv']['kernel'],
unfreeze_target['self_attention']['qkv']['kernel'].shape)
unfreeze_target['encoder_decoder_attention']['kv']['kernel'] = \
jnp.reshape(ref['encoder_decoder_attention']['kv']['kernel'],
unfreeze_target['encoder_decoder_attention']['kv']['kernel'].shape)
else:
unfreeze_target['self_attention']['query']['kernel'] = \
ref['self_attention']['query']['kernel']
unfreeze_target['self_attention']['key']['kernel'] = \
ref['self_attention']['key']['kernel']
unfreeze_target['self_attention']['value']['kernel'] = \
ref['self_attention']['value']['kernel']
unfreeze_target['encoder_decoder_attention']['query']['kernel'] = \
ref['encoder_decoder_attention']['query']['kernel']
for scope in ['self_attention', 'encoder_decoder_attention']:
unfreeze_scope = unfreeze_target[scope]
ref_scope = ref[scope]
for key in unfreeze_scope.keys():
unfreeze_scope[key]['kernel'] = \
ref_scope[key]['kernel'].reshape(unfreeze_scope[key]['kernel'].shape)
unfreeze_target['mlp']['wi_kernel'] = \
jnp.reshape(ref['mlp']['wi']['kernel'], unfreeze_target['mlp']['wi_kernel'].shape)
unfreeze_target['mlp']['wo_kernel'] = \
......@@ -412,7 +406,7 @@ class TestDecoderLayer:
test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
test_masks)
ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params, attrs)
ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params)
ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng)
test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng)
......@@ -459,7 +453,7 @@ class TestDecoderLayer:
test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
test_masks)
ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params, attrs)
ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params)
if FP8Helper.is_fp8_enabled():
for _ in range(4):
......@@ -483,11 +477,14 @@ class TestDecoderLayer:
assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad
def reorganize_test_wgrad(test_wgrad, attrs):
fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)
attn_name = 'self_attention'
num_heads = attrs.get(_KEY_OF_NUM_HEADS)
num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads)
fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True) and \
num_heads == num_gqa_groups
unfreeze_test_wgrad = flax.core.unfreeze(test_wgrad)
if "output_layernorm" not in attrs:
attn_name = 'self_attention'
unfreeze_test_wgrad['pre_self_attention_layer_norm'] = {}
pre_attn_layer_key = 'qkv' if fuse_qkv else 'query'
unfreeze_test_wgrad['pre_self_attention_layer_norm']['scale'] = \
......@@ -498,14 +495,11 @@ class TestDecoderLayer:
unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
if fuse_qkv:
unfreeze_test_wgrad[attn_name]['qkv']['kernel'] = \
jnp.reshape(unfreeze_test_wgrad[attn_name]['qkv']['kernel'],
(unfreeze_test_wgrad[attn_name]['qkv']['kernel'].shape[0], -1))
attn_name = 'encoder_decoder_attention'
unfreeze_test_wgrad[attn_name]['kv']['kernel'] = \
jnp.reshape(unfreeze_test_wgrad[attn_name]['kv']['kernel'],
(unfreeze_test_wgrad[attn_name]['kv']['kernel'].shape[0], -1))
for scope in ['self_attention', 'encoder_decoder_attention']:
for key in unfreeze_test_wgrad[scope].keys():
unfreeze_test_wgrad[scope][key]['kernel'] = \
jnp.reshape(unfreeze_test_wgrad[scope][key]['kernel'],
(unfreeze_test_wgrad[scope][key]['kernel'].shape[0], -1))
unfreeze_test_wgrad['pre_cross_attention_layer_norm'] = {}
unfreeze_test_wgrad['pre_cross_attention_layer_norm']['scale'] = \
......
......@@ -12,6 +12,8 @@ from praxis import pax_fiddle
from praxis.base_layer import WeightInit, DEFAULT_INIT_MUTABLE_LIST
import pytest
from utils import assert_allclose
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast, update_fp8_metas, update_collections
from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral
......@@ -23,12 +25,12 @@ from transformer_engine.jax.flax import TransformerLayer as flax_TransformerLaye
from transformer_engine.jax.flax.module import Softmax
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
from transformer_engine.jax.praxis import LayerNorm
from transformer_engine.jax.praxis import FusedSoftmax, LayerNorm
from transformer_engine.jax.praxis import FusedSoftmax
from transformer_engine.jax.praxis import LayerNormLinear, LayerNormMLP, Linear
from transformer_engine.jax.praxis import MultiHeadAttention, RelativePositionBiases
from transformer_engine.jax.praxis import TransformerEngineBaseLayer, TransformerLayer, TransformerLayerType
from transformer_engine.jax.praxis import TransformerEngineBaseLayer
from transformer_engine.jax.praxis import TransformerLayer, TransformerLayerType
from transformer_engine.jax.softmax import SoftmaxType
from utils import assert_allclose
is_fp8_supported, reason = is_fp8_available()
......@@ -662,7 +664,7 @@ class TestRelativePositionBias(TestLayer):
praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)
praxis_loss= \
TestLayer.loss(praxis_variables, *test_input, module=praxis_layer, mean_out=False)
TestLayer.loss(praxis_variables, *test_input, module=praxis_layer, mean_out=False)
flax_loss = \
TestLayer.loss(flax_variables, *test_input, module=flax_layer, mean_out=False)
......@@ -674,6 +676,8 @@ class MultiHeadAttnAttr:
LN_TYPE = 'layernorm_type'
ATTN_MASK_TYPE = 'attn_mask_type'
ZERO_CEN = 'zero_centered_gamma'
NUM_ATTN_HEADS = 'num_attention_heads'
NUM_GQA_GROUPS = 'num_gqa_groups'
ATTRS = [{
USE_BIAS: True,
LN_TYPE: 'layernorm',
......@@ -704,6 +708,13 @@ class MultiHeadAttnAttr:
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ATTN_MASK_TYPE: 'causal'
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
NUM_ATTN_HEADS: 8,
NUM_GQA_GROUPS: 4,
ATTN_MASK_TYPE: 'causal'
}]
......
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utility for the TE layer tests"""
import functools
import math
......@@ -28,6 +29,9 @@ Initializer = Callable[[PRNGKey, Shape, DType], Array]
def is_devices_enough(required):
"""
Check if the available GPUs is enough
"""
return len(jax.devices()) >= required
......@@ -121,9 +125,9 @@ def dot_product_attention(query: Array,
query: queries for calculating attention with shape of `[batch, q_length,
num_heads, qk_depth_per_head]`.
key: keys for calculating attention with shape of `[batch, kv_length,
num_heads, qk_depth_per_head]`.
num_gqa_groups, qk_depth_per_head]`.
value: values to be used in attention with shape of `[batch, kv_length,
num_heads, v_depth_per_head]`.
num_gqa_groups, v_depth_per_head]`.
bias: bias for the attention weights. This should be broadcastable to the
shape `[batch, num_heads, q_length, kv_length]` This can be used for
incorporating causal masks, padding masks, proximity bias, etc.
......@@ -141,21 +145,31 @@ def dot_product_attention(query: Array,
batch_dim = 1 if transpose_batch_sequence else 0
assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], (
'q, k, v batch dims must match.')
assert query.shape[-2] == key.shape[-2] == value.shape[-2], ('q, k, v num_heads must match.')
sequence_dim = 0 if transpose_batch_sequence else 1
assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.'
assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'
assert key.shape[-2] == value.shape[-2], 'k, v num_heads must match.'
assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.'
# Casting logits and softmax computation for float32 for model stability.
if float32_logits:
query = query.astype(jnp.float32)
key = key.astype(jnp.float32)
# `attn_weights`: [batch, num_heads, q_length, kv_length]
# `attn_weights`: [batch, num_heads, groups, q_length, kv_length]
h_q, h_kv = query.shape[-2], key.shape[-2]
assert (h_q % h_kv == 0) and (h_q >= h_kv)
group_size = h_q // h_kv
grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))
if transpose_batch_sequence:
attn_weights = jnp.einsum('qbhd,kbhd->bhqk', query, key)
attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key)
else:
attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)
attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)
# reshape back to normal DPA shape for bias/softmax/dropout
b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
attn_weights_without_groups_shape = (b, h * g, q, k)
attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)
# Apply attention bias: masking, dropout, proximity bias, etc.
if bias is not None:
......@@ -174,11 +188,13 @@ def dot_product_attention(query: Array,
multiplier = (keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype))
attn_weights = attn_weights * multiplier
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
# Take the linear combination of `value`.
if transpose_batch_sequence:
return jnp.einsum('bhqk,kbhd->qbhd', attn_weights, value)
return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape)
return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)
return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape)
class DenseGeneral(nn.Module):
......@@ -235,7 +251,8 @@ class DenseGeneral(nn.Module):
if self.use_bias:
bias = nn_partitioning.param_with_axes('bias',
self.bias_init, (self.features,),
self.bias_init,
self.features,
self.dtype,
axes=self.bias_axes)
else:
......@@ -332,6 +349,7 @@ class MultiHeadAttention(nn.Module):
Attributes:
num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
should be divisible by the number of heads.
num_gqa_groups: number of kv attention heads
head_dim: dimension of each head.
dtype: the dtype of the computation.
dropout_rate: dropout rate
......@@ -340,9 +358,10 @@ class MultiHeadAttention(nn.Module):
numerical issues with bfloat16.
"""
num_heads: int
head_dim: int
transpose_batch_sequence: bool
num_heads: int = 8
num_gqa_groups: int | None = None
head_dim: int = 64
transpose_batch_sequence: bool = True
dtype: DType = jnp.float32
dropout_rate: float = 0.
kernel_init: Initializer = None
......@@ -354,6 +373,8 @@ class MultiHeadAttention(nn.Module):
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
@nn.compact
......@@ -393,18 +414,24 @@ class MultiHeadAttention(nn.Module):
Returns:
output of shape `[batch, length, q_features]`.
"""
projection = functools.partial(DenseGeneral,
axis=-1,
features=self.num_heads * self.head_dim,
kernel_axes=('embed', 'joined_kv'),
dtype=self.dtype)
q_projection = functools.partial(DenseGeneral,
axis=-1,
features=self.num_heads * self.head_dim,
kernel_axes=('embed', 'joined_kv'),
dtype=self.dtype)
kv_projection = functools.partial(DenseGeneral,
axis=-1,
features=self.num_gqa_groups * self.head_dim,
kernel_axes=('embed', 'joined_kv'),
dtype=self.dtype)
# NOTE: T5 does not explicitly rescale the attention logits by
# 1/sqrt(depth_kq)! This is folded into the initializers of the
# linear transformations, which is equivalent under Adafactor
depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
query_init = lambda *args: self.kernel_init(*args) / ( # pylint: disable=unnecessary-lambda-assignment
depth_scaling if self.scaled_query_init else 1.0)
query_init = lambda *args: self.kernel_init(*args) / (depth_scaling
if self.scaled_query_init else 1.0)
# Project inputs_q to multi-headed q/k/v
# dimensions are then [batch, length, num_heads, head_dim]
......@@ -417,13 +444,17 @@ class MultiHeadAttention(nn.Module):
v_shape = (shape[0], shape[1] // 3)
q_kernel = query_init(key, q_shape, dtype)
k_kernel = self.kernel_init(key, k_shape, dtype) # pylint: disable=too-many-function-args
v_kernel = self.kernel_init(key, v_shape, dtype) # pylint: disable=too-many-function-args
k_kernel = self.kernel_init(key, k_shape, dtype)
v_kernel = self.kernel_init(key, v_shape, dtype)
return jnp.concatenate([q_kernel, k_kernel, v_kernel], axis=-1, dtype=dtype)
is_self_attn = (inputs_q is inputs_kv)
is_gqa = (self.num_heads != self.num_gqa_groups)
is_qkvpack = (is_self_attn and not is_gqa)
if self.fuse_qkv:
if inputs_q is inputs_kv:
if is_qkvpack:
qkv_proj = DenseGeneral(axis=-1,
features=self.num_heads * self.head_dim * 3,
kernel_axes=('embed', 'joined_kv'),
......@@ -436,24 +467,24 @@ class MultiHeadAttention(nn.Module):
if self.scale_attn_logits:
query = query / depth_scaling
else:
query = projection(kernel_init=query_init, name='query')( \
query = q_projection(kernel_init=query_init, name='query')( \
(inputs_q / depth_scaling) if self.scale_attn_logits else inputs_q)
kv_proj = DenseGeneral(axis=-1,
features=self.num_heads * self.head_dim * 2,
features=self.num_gqa_groups * self.head_dim * 2,
kernel_axes=('embed', 'joined_kv'),
kernel_init=self.kernel_init,
name='kv',
dtype=self.dtype)(inputs_kv)
key, value = jnp.split(kv_proj, [self.num_heads * self.head_dim], axis=-1)
key, value = jnp.split(kv_proj, [self.num_gqa_groups * self.head_dim], axis=-1)
else:
query = projection(kernel_init=query_init, name='query')( \
query = q_projection(kernel_init=query_init, name='query')( \
(inputs_q / depth_scaling) if self.scale_attn_logits else inputs_q)
key = projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
value = projection(kernel_init=self.kernel_init, name='value')(inputs_kv)
key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv)
query = query.reshape((query.shape[0], query.shape[1], self.num_heads, self.head_dim))
key = key.reshape((key.shape[0], key.shape[1], self.num_heads, self.head_dim))
value = value.reshape((value.shape[0], value.shape[1], self.num_heads, self.head_dim))
query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
if self.transpose_batch_sequence:
query = nn_partitioning.with_sharding_constraint(query,
......@@ -476,7 +507,7 @@ class MultiHeadAttention(nn.Module):
# fusion optimization. This also enables the "scatter via one-hot
# broadcast" trick, which means we do a one-hot broadcast instead of a
# scatter/gather operations, resulting in a 3-4x speedup in practice.
swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3]) # pylint: disable=unnecessary-lambda-assignment
swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3])
cached_key = self.variable('cache', 'cached_key', jnp.zeros, swap_dims(key.shape),
key.dtype)
cached_value = self.variable('cache', 'cached_value', jnp.zeros, swap_dims(value.shape),
......@@ -755,7 +786,8 @@ class RelativePositionBiases(nn.Module):
class EncoderLayer(nn.Module):
"""Transformer encoder layer."""
relative_embedding: nn.Module = None
num_heads: int = 8
num_attention_heads: int = 8
num_gqa_groups: int | None = None
head_dim: int = 64
dropout_rate: float = 0.1
transpose_batch_sequence: bool = True
......@@ -773,6 +805,11 @@ class EncoderLayer(nn.Module):
fuse_qkv_params: bool = True
fuse_mlp_wi: bool = False
def __post_init__(self):
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
@nn.compact
def __call__(self, inputs, encoder_mask=None, deterministic=False):
# Relative position embedding as attention biases.
......@@ -782,7 +819,7 @@ class EncoderLayer(nn.Module):
if self.relative_embedding is None:
rel_emb = RelativePositionBiases(num_buckets=32,
max_distance=128,
num_heads=self.num_heads,
num_heads=self.num_attention_heads,
dtype=self.dtype,
embedding_init=nn.initializers.variance_scaling(
1.0, 'fan_avg', 'uniform'),
......@@ -807,7 +844,8 @@ class EncoderLayer(nn.Module):
x = inputs
# [batch, length, emb_dim] -> [batch, length, emb_dim]
x = MultiHeadAttention(num_heads=self.num_heads,
x = MultiHeadAttention(num_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
dtype=self.dtype,
head_dim=self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
......@@ -868,7 +906,8 @@ class EncoderLayer(nn.Module):
class DecoderLayer(nn.Module):
"""Transformer decoder layer that attends to the encoder."""
relative_embedding: nn.Module = None
num_heads: int = 8
num_attention_heads: int = 8
num_gqa_groups: int | None = None
head_dim: int = 64
dropout_rate: float = 0.1
transpose_batch_sequence: bool = True
......@@ -886,6 +925,11 @@ class DecoderLayer(nn.Module):
fuse_qkv_params: bool = True
fuse_mlp_wi: bool = False
def __post_init__(self):
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
@nn.compact
def __call__(self,
inputs,
......@@ -903,7 +947,7 @@ class DecoderLayer(nn.Module):
if self.relative_embedding is None:
rel_emb = RelativePositionBiases(num_buckets=32,
max_distance=128,
num_heads=self.num_heads,
num_heads=self.num_attention_heads,
dtype=self.dtype,
embedding_init=nn.initializers.variance_scaling(
1.0, 'fan_avg', 'uniform'),
......@@ -928,7 +972,8 @@ class DecoderLayer(nn.Module):
x = inputs
# Self-attention block
x = MultiHeadAttention(num_heads=self.num_heads,
x = MultiHeadAttention(num_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
dtype=self.dtype,
head_dim=self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
......@@ -960,7 +1005,8 @@ class DecoderLayer(nn.Module):
if self.apply_residual_connection_post_layernorm:
residual = y
y = MultiHeadAttention(num_heads=self.num_heads,
y = MultiHeadAttention(num_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
dtype=self.dtype,
head_dim=self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
......@@ -1012,6 +1058,9 @@ class DecoderLayer(nn.Module):
def make_causal_mask(batch, seqlen, dtype=jnp.uint8):
"""
Generate causal mask
"""
shape = (batch, seqlen)
idxs = jnp.broadcast_to(jnp.arange(shape[-1], dtype=jnp.int32), shape)
......@@ -1022,6 +1071,9 @@ def make_causal_mask(batch, seqlen, dtype=jnp.uint8):
def make_self_mask(batch, seqlen, dtype=jnp.uint8):
"""
Generate attention mask
"""
shape = (batch, seqlen)
mask = jnp.ones((*shape, shape[-1]))
mask = jnp.expand_dims(mask, axis=-3)
......@@ -1057,7 +1109,7 @@ def assert_allclose(
dtype = actual.dtype
# Determine tolerances
tols = dict()
tols = {}
if rtol is None or atol is None:
tols = dtype_tols(dtype)
if rtol is not None:
......
......@@ -573,14 +573,14 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_QKV->data.dtype;
const auto QKV_type = input_QKV->data.dtype;
void *devPtrQKV = input_QKV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto stride = 0;
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = 2 * num_attn_heads * head_dim;
stride = typeToSize(QKV_type) * num_attn_heads * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
stride = 2 * head_dim;
stride = typeToSize(QKV_type) * head_dim;
}
void *devPtrQ = static_cast<void *>(devPtrQKV);
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
......@@ -677,14 +677,15 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_QKV->data.dtype;
void *devPtrQKV = input_QKV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto stride = 0;
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = 2 * num_attn_heads * head_dim;
stride = typeToSize(QKV_type) * num_attn_heads * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
stride = 2 * head_dim;
stride = typeToSize(QKV_type) * head_dim;
}
void *devPtrQ = devPtrQKV;
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
......@@ -712,7 +713,6 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
const auto qkv_type = input_QKV->data.dtype;
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_bwd_impl(batch, num_attn_heads, num_attn_heads,
......@@ -723,7 +723,7 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea
devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlens, devPtrCuSeqlens,
get_cudnn_fe_dtype(qkv_type), workspace->data.dptr,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
if (workspace_size > 0) {
......@@ -750,15 +750,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_Q->data.dtype;
const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr;
void *devPtrKV = input_KV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto stride = 0;
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = 2 * num_attn_heads * head_dim;
stride = typeToSize(QKV_type) * num_gqa_groups * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = 2 * head_dim;
stride = typeToSize(QKV_type) * head_dim;
}
void *devPtrK = devPtrKV;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
......@@ -860,15 +860,14 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr;
void *devPtrKV = input_KV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto stride = 0;
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = 2 * num_attn_heads * head_dim;
stride = typeToSize(QKV_type) * num_gqa_groups * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = 2 * head_dim;
stride = typeToSize(QKV_type) * head_dim;
}
void *devPtrK = devPtrKV;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
......@@ -935,7 +934,7 @@ void fused_attn_arbitrary_seqlen_fwd(
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_Q->data.dtype;
const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr;
void *devPtrK = input_K->data.dptr;
void *devPtrV = input_V->data.dptr;
......
......@@ -1651,14 +1651,11 @@ class FusedAttnHelper:
def get_fused_attn_backend(self):
"""Get the fused attention kernel backend"""
return transformer_engine_jax.get_fused_attn_backend(jax_dtype_to_te_dtype(self.q_type),
jax_dtype_to_te_dtype(self.kv_type),
self.qkv_layout, self.attn_bias_type,
self.attn_mask_type,
self.dropout_probability,
self.num_heads_q, self.num_heads_kv,
self.max_seqlen_q, self.max_seqlen_kv,
self.head_dim)
return transformer_engine_jax.get_fused_attn_backend(
jax_dtype_to_te_dtype(self.q_type), jax_dtype_to_te_dtype(self.kv_type),
self.qkv_layout, self.attn_bias_type, self.attn_mask_type, self.dropout_probability,
self.num_heads_q, self.num_heads_kv, self.max_seqlen_q, self.max_seqlen_kv,
self.head_dim)
@dataclass(frozen=True)
......@@ -1701,12 +1698,11 @@ class _FusedAttnRNGStateChecker:
return seed
def generate_cu_seqlen(mask):
def generate_cu_seqlen(actual_seqlen):
"""
Generating cumsum seqlen for a batch
"""
seqlen = jnp.sum(mask == 0, axis=(-1, -2), dtype=jnp.int32)
cu_seqlen = jnp.cumsum(seqlen)
cu_seqlen = jnp.cumsum(actual_seqlen)
cu_seqlen = jnp.hstack((0, cu_seqlen))
return cu_seqlen
......@@ -1722,13 +1718,13 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
outer_primitive = None
@staticmethod
def abstract(qkv_aval, bias_aval, mask_or_cu_seqlen_aval, seed_aval, *, attn_bias_type,
def abstract(qkv_aval, bias_aval, seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training):
"""
Self fused attention fwd abstract
"""
# outer_primitve is squeezed_mask, inner_primitive is cu_seqlen
del mask_or_cu_seqlen_aval, scaling_factor, is_training
# outer_primitve is seqlen, inner_primitive is cu_seqlen
del seqlen_or_cu_seqlen_aval, scaling_factor, is_training
qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
*batch_shape, max_seqlen, nqkv, num_head, head_dim = qkv_aval.shape
assert nqkv == 3
......@@ -1781,19 +1777,20 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability,
attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
batch, num_head, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor,
dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
out = custom_caller(SelfFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
return out
@staticmethod
def impl(qkv, bias, squeezed_mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
def impl(qkv, bias, seqlen, seed, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
assert SelfFusedAttnFwdPrimitive.inner_primitive is not None
cu_seqlen = generate_cu_seqlen(squeezed_mask)
cu_seqlen = generate_cu_seqlen(seqlen)
output, softmax_aux, rng_state = SelfFusedAttnFwdPrimitive.inner_primitive.bind(
qkv,
......@@ -1859,10 +1856,9 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
register_primitive(SelfFusedAttnFwdPrimitive)
def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, squeezed_mask: jnp.ndarray,
seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool):
def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, seqlen: jnp.ndarray, seed: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
scaling_factor: float, dropout_probability: float, is_training: bool):
"""
Wrapper for TE self fused attention fwd
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
......@@ -1875,7 +1871,7 @@ def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, squeezed_mask: jnp.
bias = jnp.zeros(0, dtype=qkv.dtype)
return SelfFusedAttnFwdPrimitive.outer_primitive.bind(qkv,
bias,
squeezed_mask,
seqlen,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
......@@ -1896,14 +1892,14 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
@staticmethod
def abstract(qkv_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval, doutput_aval,
mask_or_cu_seqlen_aval, *, attn_bias_type, attn_mask_type, scaling_factor,
seqlen_or_cu_seqlen_aval, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
"""
Self fused attention bwd abstract
"""
del softmax_aux_aval, rng_state_aval
# outer_primitve is squeezed_mask, inner_primitive is cu_seqlen
del mask_or_cu_seqlen_aval, attn_bias_type, attn_mask_type
# outer_primitve is seqlen, inner_primitive is cu_seqlen
del seqlen_or_cu_seqlen_aval, attn_bias_type, attn_mask_type
del scaling_factor, dropout_probability, is_training
qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
......@@ -1934,19 +1930,20 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability,
attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
batch, num_head, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor,
dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
out = custom_caller(SelfFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
return out
@staticmethod
def impl(qkv, bias, softmax_aux, rng_state, output, doutput, squeezed_mask, attn_bias_type,
def impl(qkv, bias, softmax_aux, rng_state, output, doutput, seqlen, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training):
assert SelfFusedAttnBwdPrimitive.inner_primitive is not None
cu_seqlen = generate_cu_seqlen(squeezed_mask)
cu_seqlen = generate_cu_seqlen(seqlen)
dqkv, dbias = SelfFusedAttnBwdPrimitive.inner_primitive.bind(
qkv,
......@@ -2029,7 +2026,7 @@ register_primitive(SelfFusedAttnBwdPrimitive)
def self_fused_attn_bwd(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: jnp.ndarray,
rng_state: jnp.ndarray, output: jnp.ndarray, doutput: jnp.ndarray,
squeezed_mask: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
seqlen: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool):
"""
......@@ -2045,7 +2042,7 @@ def self_fused_attn_bwd(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: jnp.nd
rng_state,
output,
doutput,
squeezed_mask,
seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
......@@ -2064,13 +2061,13 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
outer_primitive = None
@staticmethod
def abstract(q_aval, kv_aval, bias_aval, q_mask_or_cu_seqlen_aval, kv_mask_or_cu_seqlen_aval,
seed_aval, *, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training):
def abstract(q_aval, kv_aval, bias_aval, q_seqlen_or_cu_seqlen_aval,
kv_seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
"""
Cross fused attention fwd abstract
"""
# outer_primitve is squeezed_mask, inner_primitive is cu_seqlen
# outer_primitve is seqlen, inner_primitive is cu_seqlen
del scaling_factor, is_training
q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
......@@ -2083,18 +2080,17 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
assert q_dtype == kv_dtype == bias_dtype
assert q_batch_shape == kv_batch_shape
assert q_num_head == kv_num_head
assert q_head_dim == kv_head_dim
assert nkv == 2
assert q_mask_or_cu_seqlen_aval.dtype == kv_mask_or_cu_seqlen_aval.dtype
assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
output_shape = q_aval.shape
output_dtype = q_dtype
backend = FusedAttnHelper(q_dtype, kv_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
attn_bias_type, attn_mask_type, dropout_probability,
q_num_head, kv_num_head,
q_max_seqlen, kv_max_seqlen, q_head_dim).get_fused_attn_backend()
attn_bias_type, attn_mask_type, dropout_probability, q_num_head,
kv_num_head, q_max_seqlen, kv_max_seqlen,
q_head_dim).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_aux_shape = (*q_batch_shape, q_num_head, q_max_seqlen, kv_max_seqlen)
......@@ -2128,7 +2124,7 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
*batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape
batch = reduce(operator.mul, batch_shape)
kv_max_seqlen = kv_aval.shape[-4]
kv_max_seqlen, kv_num_head = kv_aval.shape[-4], kv_aval.shape[-2]
operands = [q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed]
operand_shapes = map(lambda x: x.type.shape, operands)
......@@ -2139,7 +2135,7 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim,
batch, num_head, kv_num_head, q_max_seqlen, kv_max_seqlen, head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training)
......@@ -2148,12 +2144,12 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
return out
@staticmethod
def impl(q, kv, bias, q_squeezed_mask, kv_squeezed_mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
def impl(q, kv, bias, q_seqlen, kv_seqlen, seed, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
assert CrossFusedAttnFwdPrimitive.inner_primitive is not None
q_cu_seqlen = generate_cu_seqlen(q_squeezed_mask)
kv_cu_seqlen = generate_cu_seqlen(kv_squeezed_mask)
q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
output, softmax_aux, rng_state = CrossFusedAttnFwdPrimitive.inner_primitive.bind(
q,
......@@ -2224,9 +2220,8 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
register_primitive(CrossFusedAttnFwdPrimitive)
def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
q_squeezed_mask: jnp.ndarray, kv_squeezed_mask: jnp.ndarray,
seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, q_seqlen: jnp.ndarray,
kv_seqlen: jnp.ndarray, seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool):
"""
......@@ -2243,8 +2238,8 @@ def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
return CrossFusedAttnFwdPrimitive.outer_primitive.bind(q,
kv,
bias,
q_squeezed_mask,
kv_squeezed_mask,
q_seqlen,
kv_seqlen,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
......@@ -2296,7 +2291,7 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
*batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape
batch = reduce(operator.mul, batch_shape)
kv_max_seqlen = kv_aval.shape[-4]
kv_max_seqlen, kv_num_head = kv_aval.shape[-4], kv_aval.shape[-2]
operands = [q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen]
operand_shapes = map(lambda x: x.type.shape, operands)
......@@ -2310,7 +2305,7 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
# the dropout elements are encoded in the forward auxiliary tensor
# so seed is not needed in backward
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim,
batch, num_head, kv_num_head, q_max_seqlen, kv_max_seqlen, head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training)
......@@ -2319,13 +2314,12 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
return out
@staticmethod
def impl(q, kv, bias, softmax_aux, rng_state, output, doutput, q_squeezed_mask,
kv_squeezed_mask, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training):
def impl(q, kv, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen,
attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training):
assert CrossFusedAttnBwdPrimitive.inner_primitive is not None
q_cu_seqlen = generate_cu_seqlen(q_squeezed_mask)
kv_cu_seqlen = generate_cu_seqlen(kv_squeezed_mask)
q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
dq, dkv, dbias = CrossFusedAttnBwdPrimitive.inner_primitive.bind(
q,
......@@ -2417,10 +2411,9 @@ register_primitive(CrossFusedAttnBwdPrimitive)
def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray,
doutput: jnp.ndarray, q_squeezed_mask: jnp.ndarray,
kv_squeezed_mask: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool):
doutput: jnp.ndarray, q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
scaling_factor: float, dropout_probability: float, is_training: bool):
"""
Wrapper for TE cross fused attention bwd
Return the gradients of cross fused attention with packed kv input
......@@ -2435,8 +2428,8 @@ def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
rng_state,
output,
doutput,
q_squeezed_mask,
kv_squeezed_mask,
q_seqlen,
kv_seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
......
......@@ -82,12 +82,12 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch,
}
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
size_t batch, size_t num_head, size_t num_gqa_groups, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
return PackOpaque(CustomCallFusedAttnDescriptor{batch, num_head, q_max_seqlen, kv_max_seqlen,
head_dim, scaling_factor, dropout_probability,
bias_type, mask_type, dtype, is_training});
return PackOpaque(CustomCallFusedAttnDescriptor{
batch, num_head, num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim, scaling_factor,
dropout_probability, bias_type, mask_type, dtype, is_training});
}
void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream,
......@@ -745,8 +745,8 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
size_t head_dim) {
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_num_heads, kv_num_heads,
q_max_seqlen, kv_max_seqlen, head_dim);
mask_type, dropout_probability, q_num_heads, kv_num_heads, q_max_seqlen, kv_max_seqlen,
head_dim);
return backend;
}
......@@ -768,6 +768,7 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu
auto batch = descriptor.batch;
auto num_head = descriptor.num_head;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim;
......@@ -779,6 +780,9 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu
NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
"q_max_seqlen should be equal to kv_max_seqlen in the self attention.");
NVTE_CHECK(num_head == num_gqa_groups,
"num_head should be equal to num_gqa_groups in the qkvpacked attention");
auto dtype = descriptor.dtype;
auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
......@@ -799,10 +803,10 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu
// aux tensors
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, num_head, num_head,
q_max_seqlen, kv_max_seqlen, head_dim);
auto backend =
nvte_get_fused_attn_backend(static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype),
qkv_layout, bias_type, mask_type, dropout_probability, num_head,
num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
NVTETensorPack aux_output_tensors;
......@@ -853,6 +857,7 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq
auto batch = descriptor.batch;
auto num_head = descriptor.num_head;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim;
......@@ -864,6 +869,9 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq
NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
"q_max_seqlen should be equal to kv_max_seqlen in the self attention.");
NVTE_CHECK(num_head == num_gqa_groups,
"num_head should be equal to num_gqa_groups in the qkvpacked attention");
auto dtype = descriptor.dtype;
auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim};
auto output_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
......@@ -941,6 +949,7 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
auto batch = descriptor.batch;
auto num_head = descriptor.num_head;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim;
......@@ -951,7 +960,7 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
auto dtype = descriptor.dtype;
auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_head, head_dim};
auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
// input tensors
......@@ -976,10 +985,10 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, num_head, num_head,
q_max_seqlen, kv_max_seqlen, head_dim);
auto backend =
nvte_get_fused_attn_backend(static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype),
qkv_layout, bias_type, mask_type, dropout_probability, num_head,
num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
NVTETensorPack aux_output_tensors;
......@@ -1035,6 +1044,7 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
auto batch = descriptor.batch;
auto num_head = descriptor.num_head;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim;
......@@ -1045,7 +1055,7 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
auto dtype = descriptor.dtype;
auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_head, head_dim};
auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto output_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
......
......@@ -98,6 +98,7 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch,
struct CustomCallFusedAttnDescriptor {
size_t batch;
size_t num_head;
size_t num_gqa_groups;
size_t q_max_seqlen;
size_t kv_max_seqlen;
size_t head_dim;
......@@ -110,8 +111,8 @@ struct CustomCallFusedAttnDescriptor {
};
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
size_t batch, size_t num_head, size_t num_gqa_groups, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, DType dtype, bool is_training);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
......
......@@ -16,7 +16,6 @@ import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import dtypes
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap
......@@ -198,22 +197,31 @@ def core_attention(query: Array,
batch_dim = 1 if transpose_batch_sequence else 0
assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], (
'q, k, v batch dims must match.')
assert query.shape[-2] == key.shape[-2] == value.shape[-2], ('q, k, v num_heads must match.')
sequence_dim = 0 if transpose_batch_sequence else 1
assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.'
assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'
assert key.shape[-2] == value.shape[-2], 'k, v num_heads must match.'
assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.'
if float32_logits:
query = query.astype(jnp.float32)
key = key.astype(jnp.float32)
h_q, h_kv = query.shape[-2], key.shape[-2]
assert (h_q % h_kv == 0) and (h_q >= h_kv)
group_size = h_q // h_kv
grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))
if transpose_batch_sequence:
attn_weights = jnp.einsum('qbhd,kbhd->bhqk', query, key)
attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key)
else:
attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)
attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)
attn_weights = checkpoint_name(attn_weights, 'logits')
b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
attn_weights_without_groups_shape = (b, h * g, q, k)
attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)
attn_weights = _with_sharding_constraint(attn_weights,
(BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES))
......@@ -229,6 +237,8 @@ def core_attention(query: Array,
attn_weights = Softmax(softmax_type=softmax_type,
scale_factor=fused_scale_factor)(attn_weights, mask, bias).astype(dtype)
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
if not deterministic and dropout_rate > 0.:
keep_prob = 1.0 - dropout_rate
dropout_shape = list(attn_weights.shape)
......@@ -238,9 +248,9 @@ def core_attention(query: Array,
attn_weights = attn_weights * multiplier
if transpose_batch_sequence:
return jnp.einsum('bhqk,kbhd->qbhd', attn_weights, value)
return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape)
return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)
return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape)
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
......@@ -262,6 +272,14 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
The hidden dimension of each attention head.
num_heads : int
The number of attention heads
num_gqa_groups : int, default = `None`
Number of GQA groups. When `None` is present, it is equal to num_heads.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the querys.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
dropout_rate : float, default = 0.0
Dropout probability for the dropout op during multi-head attention.
dropout_rng_name: str, default = 'dropout'
......@@ -321,6 +339,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
head_dim: int
num_heads: int
num_gqa_groups: int | None = None
dropout_rate: float = 0.
dropout_rng_name: str = 'dropout'
layernorm_type: str = "layernorm"
......@@ -342,6 +361,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_heads
super().__post_init__()
@nn.compact
......@@ -428,30 +449,22 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
"supported attn_mask_type = {'causal', 'padding'}")
is_self_attn = (inputs_q is inputs_kv)
is_gqa = (self.num_heads != self.num_gqa_groups)
is_qkvpack = (is_self_attn and not is_gqa)
qkv_layout = QKVLayout.BS3HD if is_self_attn else QKVLayout.BSHD_BS2HD
attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
canonicalize_dtype = dtypes.canonicalize_dtype(self.dtype)
q_seqlen = inputs_q.shape[0] if self.transpose_batch_sequence else inputs_q.shape[1]
kv_seqlen = inputs_kv.shape[0] if self.transpose_batch_sequence else inputs_kv.shape[1]
enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))
def _check_seqlen(seqlen):
return seqlen % 64 == 0
def _check_head_dim(head_dim):
return head_dim in [64, 128]
has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype, qkv_layout,
attn_bias_type, attn_mask_type,
self.dropout_rate,
self.num_heads, self.num_heads,
q_seqlen, kv_seqlen, self.head_dim)
self.dropout_rate, self.num_heads,
self.num_gqa_groups, q_seqlen,
kv_seqlen, self.head_dim)
use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \
canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \
_check_seqlen(q_seqlen) and _check_seqlen(kv_seqlen) and \
_check_head_dim(self.head_dim) and \
has_fused_attn_kernel and \
enable_fused_attn
......@@ -464,17 +477,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
f"but got {self.transpose_batch_sequence}, "
if not self.fuse_qkv:
reason += f"fuse_qkv=True is required but got {self.fuse_qkv}, "
if canonicalize_dtype not in [jnp.bfloat16, jnp.float16]:
reason += f"dtype in [BF16, FP16] is required " \
f"but got dtype={canonicalize_dtype}, "
if not _check_seqlen(q_seqlen):
reason += f"q_seqlen % 64 == 0 is required " \
f"but got {q_seqlen=}, "
if not _check_seqlen(kv_seqlen):
reason += f"kv_seqlen % 64 == 0 is required " \
f"but got {kv_seqlen=}, "
if not _check_head_dim(self.head_dim):
reason += f"head_dim should be 64 or 128 but got {self.head_dim}, "
if not has_fused_attn_kernel:
reason += "no fused attention kernel is available, "
......@@ -484,7 +486,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
residual = inputs_q
if self.fuse_qkv:
if is_self_attn:
if is_qkvpack:
qkv_proj, ln_out = LayerNormDenseGeneral(
enable_layernorm=not self.output_layernorm,
layernorm_type=self.layernorm_type,
......@@ -515,7 +517,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
axis=-1,
features=self.num_heads * self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm,
return_layernorm_output=(self.apply_residual_connection_post_layernorm
or is_self_attn),
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes=(W_FSDP_AXES, W_TP_AXES),
......@@ -525,8 +528,13 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
dtype=self.dtype,
kernel_init=query_init,
name='query')(inputs_q)
if is_self_attn:
assert ln_out is not None
inputs_kv = ln_out
kv_proj = DenseGeneral(axis=-1,
features=(2, self.num_heads * self.head_dim),
features=(2, self.num_gqa_groups * self.head_dim),
transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_init=kv_init,
......@@ -542,7 +550,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
kv_projection = functools.partial(
DenseGeneral,
axis=-1,
features=self.num_heads * self.head_dim,
features=self.num_gqa_groups * self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=(W_FSDP_AXES, W_TP_AXES),
use_bias=self.use_bias,
......@@ -583,9 +591,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
query = checkpoint_name(query, 'query_proj')
key = checkpoint_name(key, 'key_proj')
value = checkpoint_name(value, 'value_proj')
query = query.reshape((query.shape[0], query.shape[1], self.num_heads, self.head_dim))
key = key.reshape((key.shape[0], key.shape[1], self.num_heads, self.head_dim))
value = value.reshape((value.shape[0], value.shape[1], self.num_heads, self.head_dim))
query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
qkv_sharding_constraint = \
(SEQLEN_AXES, BATCH_AXES, HEAD_AXES, HIDDEN_AXES) \
if self.transpose_batch_sequence \
......@@ -650,7 +658,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
# ensure the old key never used
del dropout_rng
if is_self_attn:
if is_qkvpack:
qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim))
qkv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
HIDDEN_AXES)
......@@ -667,7 +675,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
else:
assert bias is None
query = query.reshape((*query.shape[:-1], self.num_heads, self.head_dim))
kv_proj = kv_proj.reshape((*kv_proj.shape[:-1], self.num_heads, self.head_dim))
kv_proj = kv_proj.reshape((*kv_proj.shape[:-1], self.num_gqa_groups, self.head_dim))
q_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES)
kv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
HIDDEN_AXES)
......@@ -865,6 +873,14 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
Intermediate size to which input samples are projected.
num_attention_heads: int, default = 8
Number of attention heads in the transformer layer.
num_gqa_groups : int, default = `None`
Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the querys.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
Indicate the type of layer normalization.
layernorm_epsilon: float, default = 1e-6
......@@ -961,6 +977,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
hidden_size: int = 512
mlp_hidden_size: int = 2048
num_attention_heads: int = 8
num_gqa_groups: int | None = None
layernorm_type: str = 'layernorm'
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
......@@ -995,6 +1012,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
if self.mlp_kernel_init is None:
self.mlp_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in',
'truncated_normal')
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
@nn.compact
......@@ -1091,6 +1110,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
num_heads=self.num_attention_heads,
dtype=self.dtype,
head_dim=head_dim,
num_gqa_groups=self.num_gqa_groups,
transpose_batch_sequence=self.transpose_batch_sequence,
dropout_rate=self.attention_dropout,
dropout_rng_name=self.dropout_rng_name,
......@@ -1141,6 +1161,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
num_heads=self.num_attention_heads,
dtype=self.dtype,
head_dim=head_dim,
num_gqa_groups=self.num_gqa_groups,
transpose_batch_sequence=self.transpose_batch_sequence,
dropout_rate=self.attention_dropout,
dropout_rng_name=self.dropout_rng_name,
......
......@@ -40,8 +40,8 @@ class QKVLayout(Enum):
def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, attn_mask_type,
dropout_probability, num_heads_q, num_heads_kv,
max_seqlen_q, max_seqlen_kv, head_dim):
dropout_probability, num_heads_q, num_heads_kv, max_seqlen_q,
max_seqlen_kv, head_dim):
"""
To check whether the fused attention kernel is available
"""
......@@ -83,10 +83,11 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.nda
seed: jnp.ndarray, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
squeezed_mask = mask[..., 0]
mask = jnp.logical_not(mask)
actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
output, softmax_aux, rng_state = self_fused_attn_fwd(qkv,
bias,
squeezed_mask,
actual_seqlen,
seed,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
......@@ -96,12 +97,12 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.nda
output = checkpoint_name(output, 'context')
softmax_aux = checkpoint_name(softmax_aux, 'context')
rng_state = checkpoint_name(rng_state, 'context')
return output, (qkv, bias, softmax_aux, rng_state, output, squeezed_mask)
return output, (qkv, bias, softmax_aux, rng_state, output, actual_seqlen)
def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, dz):
qkv, bias, softmax_aux, rng_state, output, squeezed_mask = ctx
qkv, bias, softmax_aux, rng_state, output, actual_seqlen = ctx
grad_qkv, grad_bias = self_fused_attn_bwd(qkv,
bias,
......@@ -109,7 +110,7 @@ def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dr
rng_state,
output,
dz,
squeezed_mask,
actual_seqlen,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
......@@ -159,14 +160,19 @@ def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask:
def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
q_squeezed_mask = mask[..., 0]
kv_squeezed_mask = mask[..., 0, :]
mask = jnp.logical_not(mask)
q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
else:
# When mask is padding + causal, the actual seqlen is not the last row, use max to find it
kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))
output, softmax_aux, rng_state = cross_fused_attn_fwd(q,
kv,
bias,
q_squeezed_mask,
kv_squeezed_mask,
q_actual_seqlen,
kv_actual_seqlen,
seed,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
......@@ -174,12 +180,12 @@ def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mas
dropout_probability=dropout_probability,
is_training=is_training)
return output, (q, kv, bias, softmax_aux, rng_state, output, q_squeezed_mask, kv_squeezed_mask)
return output, (q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen)
def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, dz):
q, kv, bias, softmax_aux, rng_state, output, q_squeezed_mask, kv_squeezed_mask = ctx
q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx
grad_q, grad_kv, grad_bias = cross_fused_attn_bwd(q,
kv,
......@@ -188,8 +194,8 @@ def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, d
rng_state,
output,
dz,
q_squeezed_mask,
kv_squeezed_mask,
q_actual_seqlen,
kv_actual_seqlen,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
......
......@@ -64,6 +64,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
head_dim: int = 64
num_heads: int = 16
num_gqa_groups: int | None = None
dropout_rate: float = 0.
dropout_rng_name: str = 'dropout'
layernorm_type: str = "layernorm"
......@@ -80,6 +81,11 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
scaled_query_init: bool = True
float32_logits: bool = False
def __post_init__(self):
    """Resolve defaulted fields before delegating to the parent hook.

    A ``num_gqa_groups`` of ``None`` means "no grouping requested", which
    collapses to ordinary multi-head attention: one KV group per query head.
    NOTE(review): assignment here relies on the enclosing module allowing
    attribute mutation during ``__post_init__`` (flax-style dataclass) —
    confirm against the base class, which is outside this view.
    """
    if self.num_gqa_groups is None:
        # GQA-H (groups == heads) is equivalent to standard MHA.
        self.num_gqa_groups = self.num_heads
    super().__post_init__()
def setup(self) -> None:
"""setup"""
super().setup()
......@@ -89,6 +95,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
dtype=self.dtype,
head_dim=self.head_dim,
num_heads=self.num_heads,
num_gqa_groups=self.num_gqa_groups,
dropout_rate=self.dropout_rate,
dropout_rng_name=self.dropout_rng_name,
layernorm_type=self.layernorm_type,
......@@ -131,6 +138,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
hidden_size: int = 512
mlp_hidden_size: int = 2048
num_attention_heads: int = 8
num_gqa_groups: int | None = None
layernorm_type: str = 'layernorm'
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
......@@ -156,6 +164,11 @@ class TransformerLayer(TransformerEngineBaseLayer):
scale_attn_logits: bool = False
scaled_query_init: bool = True
def __post_init__(self):
    """Resolve defaulted fields before delegating to the parent hook.

    A ``num_gqa_groups`` of ``None`` means "no grouping requested", which
    collapses to ordinary multi-head attention: one KV group per query head.
    NOTE(review): assignment here relies on the enclosing module allowing
    attribute mutation during ``__post_init__`` (flax-style dataclass) —
    confirm against the base class, which is outside this view.
    """
    if self.num_gqa_groups is None:
        # GQA-H (groups == heads) is equivalent to standard MHA.
        self.num_gqa_groups = self.num_attention_heads
    super().__post_init__()
def setup(self) -> None:
"""setup"""
super().setup()
......@@ -186,6 +199,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
hidden_size=self.hidden_size,
mlp_hidden_size=self.mlp_hidden_size,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
layernorm_type=self.layernorm_type,
layernorm_epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment