Unverified commit 8f6c5248, authored by zlsh80826 and committed by GitHub
Browse files

[JAX][Common] Support GQA (#578)



* Support num_gqa_groups arguments
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add GQA support on the JAX bridge code
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix the kv stride of the arbitrary backend
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Completely rewrite the fused attention tests and add GQA coverage
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Support unfused GQA
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Calculate seqlen before the primitive for better performance
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add GQA layer tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Apply code style checks for te_jax
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Apply code style checks for tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add num_gqa_groups doc
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refine the qkv_type
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Correct the variable naming
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Handle Max512 CAUSAL
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add a workaround (WAR) for the latest JAX image
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent daad219f
...@@ -4,6 +4,9 @@ ...@@ -4,6 +4,9 @@
set -xe set -xe
# WAR(rewang) for the "Check failed: reduction_kind.has_value()"
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_xla_runtime_executable=true"
: ${TE_PATH:=/opt/transformerengine} : ${TE_PATH:=/opt/transformerengine}
pytest -Wignore -v $TE_PATH/tests/jax/test_distributed_* pytest -Wignore -v $TE_PATH/tests/jax/test_distributed_*
...@@ -14,5 +14,7 @@ pytest -Wignore -v $TE_PATH/examples/jax/mnist ...@@ -14,5 +14,7 @@ pytest -Wignore -v $TE_PATH/examples/jax/mnist
# Make the encoder tests run-to-run deterministic for stable CI results # Make the encoder tests run-to-run deterministic for stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
# WAR(rewang) for the "Check failed: reduction_kind.has_value()"
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_xla_runtime_executable=true"
pytest -Wignore -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py pytest -Wignore -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
pytest -Wignore -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py pytest -Wignore -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Tests for fused attention""" """Tests for fused attention"""
import os from dataclasses import dataclass
from enum import Enum from functools import partial
from math import sqrt from math import sqrt
import jax import jax
...@@ -13,18 +13,15 @@ import numpy as np ...@@ -13,18 +13,15 @@ import numpy as np
import pytest import pytest
from flax.linen import combine_masks from flax.linen import combine_masks
from flax.linen import dot_product_attention
from flax.linen import make_attention_mask from flax.linen import make_attention_mask
from flax.linen import make_causal_mask from flax.linen.dtypes import promote_dtype
from jax import Array
from jax import value_and_grad, jit from jax import value_and_grad, jit
from jax.typing import ArrayLike, DTypeLike
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn
from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
from transformer_engine_jax import get_device_compute_capability # pylint: disable=wrong-import-order
# Type annotations
Array = jnp.ndarray
@pytest.fixture(autouse=True, scope='function') @pytest.fixture(autouse=True, scope='function')
...@@ -32,34 +29,55 @@ def clear_live_arrays(): ...@@ -32,34 +29,55 @@ def clear_live_arrays():
""" """
Clear all live arrays to keep the resource clean Clear all live arrays to keep the resource clean
""" """
# Calling customcalls before jax may cause CUDA uninitialize error
_ = jnp.zeros(0)
yield yield
for arr in jax.live_arrays(): for arr in jax.live_arrays():
arr.delete() arr.delete()
class Backend(Enum): def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: ArrayLike,
bias: ArrayLike, mask: ArrayLike, deterministic: bool,
dropout_rate: float, dropout_rng: ArrayLike,
dtype: DTypeLike) -> Array:
""" """
Fused attn backend. Similar to flax.linen.dot_product_attention but with GQA support
Unit tests only, transformer will auto dispatch to the best backend
""" """
Max512 = "0" query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
Arbitrary = "1" dtype = query.dtype
depth = query.shape[-1]
query = query / jnp.sqrt(depth).astype(dtype)
@pytest.fixture(name="backend", params=[Backend.Max512, Backend.Arbitrary]) b, s_q, h_q, d = query.shape
def fixture_backend(request): _, _, h_kv, _ = key.shape
""" assert (h_q % h_kv == 0) and (h_q >= h_kv)
Fixture of setting up/tearing down backend num_groups = h_q // h_kv
""" grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
backend = request.param # logits with shape (b, h_kv, num_groups, s_q, s_kv)
os.environ["NVTE_FUSED_ATTN_BACKEND"] = backend.value logits = jnp.einsum('...qhgd,...khd->...hgqk', grouped_query, key)
yield backend
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "" if bias is not None:
if bias.ndim != logits.ndim:
bias = bias.reshape((1, *logits.shape[1:]))
logits = logits + bias
if mask is not None:
if mask.ndim != logits.ndim:
mask = jnp.expand_dims(mask, axis=-3)
logits = jnp.where(mask, logits, jnp.finfo(dtype).min)
softmax_out = jax.nn.softmax(logits).astype(dtype)
if not deterministic and dropout_rate > 0.:
keep_prob = 1.0 - dropout_rate
keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape)
multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
softmax_out = softmax_out * multiplier
SELF_CASES = [(32, 512, 16, 64), (32, 128, 16, 64), (4, 2048, 12, 64)] context = jnp.einsum('...hgqk,...khd->...qhgd', softmax_out, value)
CROSS_CASES = [(32, 128, 512, 16, 64)] context = jnp.reshape(context, query.shape)
DTYPES = [jnp.bfloat16, jnp.float16] return context
def is_causal_mask(mask: AttnMaskType): def is_causal_mask(mask: AttnMaskType):
...@@ -69,31 +87,28 @@ def is_causal_mask(mask: AttnMaskType): ...@@ -69,31 +87,28 @@ def is_causal_mask(mask: AttnMaskType):
return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK] return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]
def make_decoder_mask(tokens: Array) -> Array: def make_decoder_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
""" """
Create padded causal mask Create padded causal mask
""" """
causal_mask = make_causal_mask(tokens) q_idxs = jnp.broadcast_to(jnp.arange(q_tokens.shape[-1], dtype=jnp.int32), q_tokens.shape)
padding_mask = make_attention_mask(tokens > 0, tokens > 0) kv_idxs = jnp.broadcast_to(jnp.arange(kv_tokens.shape[-1], dtype=jnp.int32), kv_tokens.shape)
causal_mask = make_attention_mask(q_idxs, kv_idxs, jnp.greater_equal)
padding_mask = make_attention_mask(q_tokens > 0, kv_tokens > 0)
return combine_masks(causal_mask, padding_mask) return combine_masks(causal_mask, padding_mask)
def jax_self_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs): def jax_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
""" """
Self attention with JAX native implementation JAX native dot product attention implementation
""" """
attn_mask_type = kwargs['attn_mask_type'] attn_mask_type = kwargs['attn_mask_type']
if is_causal_mask(attn_mask_type): if is_causal_mask(attn_mask_type):
mask = make_decoder_mask(q_token) mask = make_decoder_mask(q_token, kv_token)
else: else:
mask = make_attention_mask(q_token > 0, kv_token > 0) mask = make_attention_mask(q_token > 0, kv_token > 0)
query, key, value = jnp.split(qkv, [1, 2], axis=-3) output = general_dot_product_attention(query,
query = jnp.squeeze(query)
key = jnp.squeeze(key)
value = jnp.squeeze(value)
output = dot_product_attention(query,
key, key,
value, value,
bias=bias, bias=bias,
...@@ -102,485 +117,259 @@ def jax_self_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs): ...@@ -102,485 +117,259 @@ def jax_self_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
dropout_rate=kwargs['dropout_probability'], dropout_rate=kwargs['dropout_probability'],
dropout_rng=dropout_rng, dropout_rng=dropout_rng,
dtype=jnp.float32) dtype=jnp.float32)
return output.astype(qkv.dtype) return output.astype(query.dtype)
def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs): def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
""" """
Cross attention with JAX native implementation TE customcall dot product attention implementation
""" """
assert q.dtype == kv.dtype
attn_mask_type = kwargs['attn_mask_type'] attn_mask_type = kwargs['attn_mask_type']
if is_causal_mask(attn_mask_type): if is_causal_mask(attn_mask_type):
raise NotImplementedError mask = make_decoder_mask(q_token, kv_token)
mask = make_attention_mask(q_token > 0, kv_token > 0)
query = q
key, value = jnp.split(kv, [1], axis=-3)
key = jnp.squeeze(key)
value = jnp.squeeze(value)
output = dot_product_attention(query,
key,
value,
bias=None,
mask=mask,
deterministic=not kwargs['is_training'],
dropout_rate=kwargs['dropout_probability'],
dropout_rng=dropout_rng,
dtype=jnp.float32)
return output.astype(q.dtype)
def customcall_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
"""
Self fused attention
"""
attn_mask_type = kwargs['attn_mask_type']
if is_causal_mask(attn_mask_type):
mask = make_decoder_mask(q_token)
else: else:
mask = make_attention_mask(q_token > 0, kv_token > 0) mask = make_attention_mask(q_token > 0, kv_token > 0)
# mask invert # mask invert
mask = (mask == 0) mask = jnp.logical_not(mask)
return self_fused_attn(qkv, bias, mask, dropout_rng, **kwargs) qkv_layout = kwargs.pop('qkv_layout')
match qkv_layout:
case QKVLayout.BS3HD:
def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs): query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
qkv = jnp.concatenate((query, key, value), axis=-3)
return self_fused_attn(qkv, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
case QKVLayout.BSHD_BS2HD:
key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
kv = jnp.concatenate((key, value), axis=-3)
return cross_fused_attn(query, kv, bias, mask, dropout_rng,
**kwargs).astype(query.dtype)
@dataclass
class FusedAttnRunner:
""" """
Cross fused attention Fused attention runner
""" """
assert q.dtype == kv.dtype batch_size: int
max_seqlen_q: int
attn_mask_type = kwargs['attn_mask_type'] max_seqlen_kv: int
if is_causal_mask(attn_mask_type): num_heads_q: int
raise NotImplementedError num_heads_kv: int
mask = make_attention_mask(q_token > 0, kv_token > 0) head_dim: int
attn_bias_type: AttnBiasType
# mask invert attn_mask_type: AttnMaskType
mask = (mask == 0) dropout_prob: float
dtype: DTypeLike
return cross_fused_attn(q, kv, None, mask, dropout_rng, **kwargs) is_training: bool
qkv_layout: QKVLayout
def _check_configs(self):
if self.qkv_layout == QKVLayout.BS3HD and self.num_heads_q != self.num_heads_kv:
pytest.skip("BS3HD layout requires num_heads_q and num_heads_kv to be equal.")
if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv:
pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.")
if not is_fused_attn_kernel_available(
self.dtype, self.dtype, self.qkv_layout, self.attn_bias_type, self.attn_mask_type,
self.dropout_prob, self.num_heads_q, self.num_heads_kv, self.max_seqlen_q,
self.max_seqlen_kv, self.head_dim):
pytest.skip("Unsupported inputs combination or device compute capability.")
def _setup_inputs(self):
self._check_configs()
key = jax.random.PRNGKey(0)
q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
@pytest.mark.parametrize('b, s, h, d', SELF_CASES) q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
@pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS]) k_shape = v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim)
@pytest.mark.parametrize('attn_mask_type', [ bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK,
AttnMaskType.PADDING_CAUSAL_MASK
])
@pytest.mark.parametrize('dropout_probability', [0., 0.1])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True, False])
class TestSelfFusedAttn():
"""Tests for transformer_engine.jax.fused_attn.self_fused_attn"""
@staticmethod self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1)
def _check_inputs(s, *, attn_bias_type, attn_mask_type, backend, dropout_probability, dtype, self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1)
num_heads_q, num_heads_kv, head_dim): self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1)
assert isinstance(backend, Backend) with_bias = self.attn_bias_type != AttnBiasType.NO_BIAS
self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1) if with_bias else None
if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type, if self.attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
attn_mask_type, dropout_probability,
num_heads_q, num_heads_kv,
s, s, head_dim):
pytest.skip("Unsupported inputs combination or device compute capability.")
def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend,
dropout_probability, dtype, is_training):
"""Setup the test inputs"""
self.__class__._check_inputs(s,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
backend=backend,
dropout_probability=dropout_probability,
dtype=dtype,
num_heads_q=h,
num_heads_kv=h,
head_dim=d)
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
pad_ratio = 0.0 pad_ratio = 0.0
else: else:
pad_ratio = 0.3 pad_ratio = 0.3
key = jax.random.PRNGKey(0) def gen_valid(bs, max_seqlen, pad_ratio):
subkeys = jax.random.split(key, 2) pad_len = int(max_seqlen * pad_ratio)
valid_len = max_seqlen - pad_len
tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
return valid_len, tokens
qkv_shape = (b, s, 3, h, d) self.valid_len_q, self.token_q = gen_valid(self.batch_size, self.max_seqlen_q, pad_ratio)
bias_shape = (1, h, s, s) self.valid_len_kv, self.token_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio)
pad_len = int(s * pad_ratio) self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
self.valid_len = s - pad_len self.scaling_factor = 1. / sqrt(self.head_dim)
min_val, max_val = -1, 1 def test_forward(self):
self.qkv = jax.random.uniform(subkeys[0], qkv_shape, dtype, min_val, max_val) """
Test forward without JIT
"""
self._setup_inputs()
with_bias = attn_bias_type != AttnBiasType.NO_BIAS args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
self.bias = jax.random.uniform(subkeys[1], bias_shape, dtype, min_val, kwargs = {
max_val) if with_bias else None 'attn_bias_type': self.attn_bias_type,
'attn_mask_type': self.attn_mask_type,
'scaling_factor': self.scaling_factor,
'dropout_probability': self.dropout_prob,
'is_training': self.is_training,
'qkv_layout': self.qkv_layout,
}
self.q_token = jnp.concatenate((jnp.ones((b, self.valid_len)), jnp.zeros((b, pad_len))), # Convert the outputs to float32 for the elementwise comparison
axis=-1) primitive_out = customcall_fused_dpa(*args, **kwargs).astype(jnp.float32)
self.kv_token = self.q_token reference_out = jax_dpa(*args, **kwargs).astype(jnp.float32)
self.scaling_factor = 1. / sqrt(d) primitive_valid, primitive_invalid = jnp.split(primitive_out, (self.valid_len_q,), axis=1)
self.dropout_probability = dropout_probability reference_valid, _ = jnp.split(reference_out, (self.valid_len_q,), axis=1)
self.dropout_rng = jax.random.PRNGKey(0) if self.dropout_probability > 0 else None
self.attn_bias_type = attn_bias_type
self.attn_mask_type = attn_mask_type
self.is_training = is_training
def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend, dropout_probability, # Skip elementwise comparison when dropout enabled
dtype, is_training): if self.is_training and self.dropout_prob > 0.:
"""
Test forward without using JIT
"""
self._set_inputs(b,
s,
h,
d,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
backend=backend,
dropout_probability=dropout_probability,
dtype=dtype,
is_training=is_training)
primitive_out = customcall_self_fused_attn(self.qkv,
self.bias,
self.q_token,
self.kv_token,
self.dropout_rng,
attn_bias_type=self.attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=self.scaling_factor,
dropout_probability=self.dropout_probability,
is_training=self.is_training)
reference_out = jax_self_attn(self.qkv,
self.bias,
self.q_token,
self.kv_token,
self.dropout_rng,
attn_mask_type=attn_mask_type,
scaling_factor=self.scaling_factor,
dropout_probability=self.dropout_probability,
is_training=self.is_training)
ref_valid, _ = jnp.split(reference_out, (self.valid_len,), axis=1)
pri_valid, pri_invalid = jnp.split(primitive_out, (self.valid_len,), axis=1)
# Dropout can't get the bitmatch result, skip the elementwise comparison
if is_training and dropout_probability > 0.:
return return
np.testing.assert_allclose(jnp.asarray(pri_valid, np.float32), np.testing.assert_allclose(primitive_valid, reference_valid, atol=1e-2, rtol=1e-4)
jnp.asarray(ref_valid, np.float32), np.testing.assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid))
rtol=1e-4,
atol=1e-2)
np.testing.assert_allclose(jnp.asarray(pri_invalid, jnp.float32),
jnp.zeros_like(pri_invalid, jnp.float32))
def test_forward_backward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend, def test_backward(self):
dropout_probability, dtype, is_training):
""" """
Test forward, backward, and autodiff by jax.value_and_grad Test value_and_grad with JIT, which includes both forward and backward
""" """
if not is_training: if not self.is_training:
pytest.skip(f"Backward doesn't support {is_training=}") pytest.skip("Backward doesn't support inference")
self._set_inputs(b, self._setup_inputs()
s,
h, def grad_func(func, *args, **kwargs):
d,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
backend=backend,
dropout_probability=dropout_probability,
dtype=dtype,
is_training=is_training)
def grad_func(fused_attn_func, *args, **kwargs):
# Gradient is small, use a gradient multiplier to amplify the gradient # Gradient is small, use a gradient multiplier to amplify the gradient
gradient_multiplier = 1000 if dtype == jnp.bfloat16 else 10000 gradient_multiplier = self.valid_len_q * self.num_heads_q
if is_causal_mask(attn_mask_type): if is_causal_mask(self.attn_mask_type):
gradient_multiplier = gradient_multiplier / 10 gradient_multiplier /= 10
# Keep only valid result for the gradient # Keep only valid result for the gradient
# fused_attn output has shape (b, s, h, d) ret_valid, _ = jnp.split(func(*args, **kwargs), (self.valid_len_q,), axis=1)
valid_fused_attn_ret, _ = jnp.split(fused_attn_func(*args, **kwargs), (self.valid_len,), return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(self.dtype)
axis=1)
return (jnp.mean(valid_fused_attn_ret, dtype=jnp.float32) *
gradient_multiplier).astype(dtype)
args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
kwargs = { kwargs = {
'attn_bias_type': self.attn_bias_type, 'attn_bias_type': self.attn_bias_type,
'attn_mask_type': attn_mask_type, 'attn_mask_type': self.attn_mask_type,
'scaling_factor': self.scaling_factor, 'scaling_factor': self.scaling_factor,
'dropout_probability': self.dropout_probability, 'dropout_probability': self.dropout_prob,
'is_training': self.is_training 'is_training': self.is_training,
'qkv_layout': self.qkv_layout,
} }
# Summing the results in FP16/BF16 may overflow, so use FP32 for the summation # Summing the results in FP16/BF16 may overflow, so use FP32 for the summation
jitted_primitive = jit( jitted_primitive = jit(
value_and_grad( value_and_grad(
lambda qkv, bias, q_token, kv_token, dropout_rng: grad_func( lambda q, k, v, bias, *args: grad_func(customcall_fused_dpa, q, k, v, bias, *args,
customcall_self_fused_attn, qkv, bias, q_token, kv_token, dropout_rng, **kwargs **kwargs), (0, 1, 2, 3)))
), (0, 1)))
jitted_reference = jit( jitted_reference = jit(
value_and_grad( value_and_grad(
lambda qkv, bias, q_token, kv_token, dropout_rng: grad_func( lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
jax_self_attn, qkv, bias, q_token, kv_token, dropout_rng, **kwargs), (0, 1))) (0, 1, 2, 3)))
primitive_out, (primitive_dqkv,
primitive_dbias) = jitted_primitive(self.qkv, self.bias, self.q_token,
self.kv_token, self.dropout_rng)
reference_out, (reference_dqkv, primitive_out, primitive_dgrad = jitted_primitive(*args)
reference_dbias) = jitted_reference(self.qkv, self.bias, self.q_token, reference_out, reference_dgrad = jitted_reference(*args)
self.kv_token, self.dropout_rng)
# Dropout can't get the bitmatch result, skip the elementwise comparison # Skip elementwise comparison when dropout enabled
if dropout_probability > 0.: if self.dropout_prob > 0.:
return return
np.testing.assert_allclose(jnp.asarray(primitive_out, np.float32), np.testing.assert_allclose(primitive_out.astype(jnp.float32),
jnp.asarray(reference_out, np.float32), reference_out.astype(jnp.float32),
rtol=1e-4, atol=1e-5,
atol=1e-5) rtol=1e-3)
valid_primitive_dqkv, invalid_primitive_dqkv = \
jnp.split(primitive_dqkv.astype(jnp.float32), (self.valid_len,), axis=1)
valid_reference_dqkv, invalid_reference_dqkv = \
jnp.split(reference_dqkv.astype(jnp.float32), (self.valid_len,), axis=1)
valid_primitive_dq, valid_primitive_dk, valid_primitive_dv = \ # Convert the outputs to float32 for the elementwise comparison
jnp.split(valid_primitive_dqkv, 3, axis=2) primitive_dq, primitive_dk, primitive_dv, primitive_dbias = map(
valid_reference_dq, valid_reference_dk, valid_reference_dv = \ jnp.float32, primitive_dgrad)
jnp.split(valid_reference_dqkv, 3, axis=2) reference_dq, reference_dk, reference_dv, reference_dbias = map(
jnp.float32, reference_dgrad)
np.testing.assert_allclose(valid_primitive_dq, valid_reference_dq, rtol=1e-4, atol=1e-5) def check_dqkv(primitive, reference, valid_len):
np.testing.assert_allclose(valid_primitive_dk, valid_reference_dk, rtol=1e-4, atol=1e-5) primitive_valid, primitive_invalid = jnp.split(primitive, (valid_len,), axis=1)
np.testing.assert_allclose(valid_primitive_dv, valid_reference_dv, rtol=1e-4, atol=1e-5) reference_valid, reference_invalid = jnp.split(reference, (valid_len,), axis=1)
assert jnp.allclose(invalid_primitive_dqkv, invalid_reference_dqkv) np.testing.assert_allclose(primitive_valid, reference_valid, atol=1e-4, rtol=1e-3)
assert jnp.allclose(primitive_invalid, reference_invalid)
assert jnp.allclose(primitive_invalid, jnp.zeros_like(primitive_invalid))
# Padded part should be 0s check_dqkv(primitive_dq, reference_dq, self.valid_len_q)
assert jnp.allclose(invalid_primitive_dqkv, jnp.zeros_like(invalid_primitive_dqkv)) check_dqkv(primitive_dk, reference_dk, self.valid_len_kv)
check_dqkv(primitive_dv, reference_dv, self.valid_len_kv)
if self.attn_bias_type != AttnBiasType.NO_BIAS: if self.attn_bias_type != AttnBiasType.NO_BIAS:
# dbias valid part # dbias valid part
np.testing.assert_allclose( np.testing.assert_allclose(primitive_dbias[..., :self.valid_len_q, :self.valid_len_kv],
jnp.asarray(primitive_dbias[:, :, :self.valid_len, :self.valid_len], np.float32), reference_dbias[..., :self.valid_len_q, :self.valid_len_kv],
jnp.asarray(reference_dbias[:, :, :self.valid_len, :self.valid_len], np.float32), atol=3e-5,
rtol=1e-4, rtol=1e-4)
atol=3e-5)
# dbias padded part # dbias padded part
np.testing.assert_allclose( np.testing.assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
jnp.asarray(primitive_dbias[:, :, self.valid_len:, self.valid_len:], np.float32), reference_dbias[..., self.valid_len_q:, self.valid_len_kv:])
jnp.asarray(reference_dbias[:, :, self.valid_len:, self.valid_len:], np.float32))
assert jnp.allclose( assert jnp.allclose(
primitive_dbias[:, :, self.valid_len:, self.valid_len:], primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
jnp.zeros_like(primitive_dbias[:, :, self.valid_len:, self.valid_len:])) jnp.zeros_like(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:]))
@pytest.mark.skipif(get_device_compute_capability(0) not in [80, 90],
reason="Fused attention kernel is not supported.")
@pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_CASES)
@pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK])
@pytest.mark.parametrize('dropout_probability', [0., 0.1])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True, False])
@pytest.mark.parametrize('pad_ratio', [0.3])
class TestCrossFusedAttn():
"""Tests for transformer_engine.jax.fused_attn.cross_fused_attn"""
def _set_inputs(self, b, s_q, s_kv, h, d, *, attn_mask_type, dropout_probability, dtype,
is_training, pad_ratio):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
q_shape = (b, s_q, h, d)
kv_shape = (b, s_kv, 2, h, d)
q_pad_len = int(s_q * pad_ratio)
kv_pad_len = int(s_kv * pad_ratio)
self.q_valid_len = s_q - q_pad_len
self.kv_valid_len = s_kv - kv_pad_len
min_val, max_val = -1, 1
self.q = jax.random.uniform(subkeys[0], q_shape, dtype, min_val, max_val)
self.kv = jax.random.uniform(subkeys[1], kv_shape, dtype, min_val, max_val)
self.q_token = jnp.concatenate((jnp.ones((b, self.q_valid_len)), jnp.zeros((b, q_pad_len))),
axis=-1)
self.kv_token = jnp.concatenate((jnp.ones((b, self.kv_valid_len)), jnp.zeros(
(b, kv_pad_len))),
axis=-1)
self.scaling_factor = 1. / sqrt(d)
self.dropout_probability = dropout_probability
self.dropout_rng = jax.random.PRNGKey(0) if self.dropout_probability > 0 else None
self.attn_bias_type = AttnBiasType.NO_BIAS
self.attn_mask_type = attn_mask_type
self.is_training = is_training
def test_forward(self, b, s_q, s_kv, h, d, attn_mask_type, dropout_probability, dtype,
is_training, pad_ratio):
"""
Test forward without using JIT
"""
self._set_inputs(b,
s_q,
s_kv,
h,
d,
attn_mask_type=attn_mask_type,
dropout_probability=dropout_probability,
dtype=dtype,
is_training=is_training,
pad_ratio=pad_ratio)
primitive_out = customcall_cross_fused_attn(self.q,
self.kv,
self.q_token,
self.kv_token,
self.dropout_rng,
attn_bias_type=self.attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=self.scaling_factor,
dropout_probability=self.dropout_probability,
is_training=self.is_training)
reference_out = jax_cross_attn(self.q,
self.kv,
self.q_token,
self.kv_token,
self.dropout_rng,
attn_mask_type=attn_mask_type,
scaling_factor=self.scaling_factor,
dropout_probability=self.dropout_probability,
is_training=self.is_training)
# Dropout can't get the bitmatch result, skip the elementwise comparison
if is_training and dropout_probability > 0.:
return
ref_valid, _ = jnp.split(reference_out, (self.q_valid_len,), axis=1)
pri_valid, pri_invalid = jnp.split(primitive_out, (self.q_valid_len,), axis=1)
np.testing.assert_allclose(jnp.asarray(pri_valid, np.float32),
jnp.asarray(ref_valid, np.float32),
rtol=1e-4,
atol=2e-3)
np.testing.assert_allclose(jnp.asarray(pri_invalid, jnp.float32),
jnp.zeros_like(pri_invalid, jnp.float32))
def test_forward_backward(self, b, s_q, s_kv, h, d, attn_mask_type, dropout_probability, dtype, @pytest.mark.parametrize('attn_bias_type', [
is_training, pad_ratio): pytest.param(AttnBiasType.NO_BIAS, id='NO_BIAS'),
pytest.param(AttnBiasType.POST_SCALE_BIAS, id='POST_SCALE_BIAS'),
])
@pytest.mark.parametrize('attn_mask_type', [
pytest.param(AttnMaskType.NO_MASK, id='NO_MASK'),
pytest.param(AttnMaskType.PADDING_MASK, id='PADDING'),
pytest.param(AttnMaskType.CAUSAL_MASK, id='CAUSAL'),
pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id='PADDING_CAUSAL'),
])
@pytest.mark.parametrize('qkv_layout', [
pytest.param(QKVLayout.BS3HD, id='qkvpacked'),
pytest.param(QKVLayout.BSHD_BS2HD, id='kvpacked'),
])
@pytest.mark.parametrize('dropout_prob', [0., 0.1])
@pytest.mark.parametrize('is_training',
[pytest.param(True, id='training'),
pytest.param(False, id='inference')])
@pytest.mark.parametrize(
'dtype', [pytest.param(jnp.bfloat16, id="BF16"),
pytest.param(jnp.float16, id="FP16")])
@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d',
[(32, 128, 128, 16, 16, 64), (4, 2048, 2048, 12, 12, 64),
pytest.param(32, 512, 128, 16, 16, 64, id='32-512-128-16-16-64-cross'),
pytest.param(4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA')])
class TestFusedAttn:
""" """
Test forward, backward, and autodiff by jax.value_and_grad Fused attention tester
""" """
if not is_training:
pytest.skip(f"Backward doesn't support {is_training=}")
self._set_inputs(b,
s_q,
s_kv,
h,
d,
attn_mask_type=attn_mask_type,
dropout_probability=dropout_probability,
dtype=dtype,
is_training=is_training,
pad_ratio=pad_ratio)
def grad_func(fused_attn_func, *args, **kwargs):
# Gradient is small, use a gradient multiplier to amplify the gradient
gradient_multiplier = 1e4
# Keep only valid result for the gradient
# fused_attn output has shape (b, s_q, h, d)
valid_fused_attn_ret, _ = jnp.split(fused_attn_func(*args, **kwargs),
(self.q_valid_len,),
axis=1)
return (jnp.mean(valid_fused_attn_ret, dtype=jnp.float32) *
gradient_multiplier).astype(dtype)
kwargs = {
'attn_bias_type': self.attn_bias_type,
'attn_mask_type': attn_mask_type,
'scaling_factor': self.scaling_factor,
'dropout_probability': self.dropout_probability,
'is_training': self.is_training
}
# Summing the results in FP16/BF16 may overflow, so use FP32 for the summation
jitted_primitive = jit(
value_and_grad(
lambda q, kv, q_token, kv_token, dropout_rng: grad_func(
customcall_cross_fused_attn, q, kv, q_token, kv_token, dropout_rng, **kwargs),
(0, 1)))
jitted_reference = jit(
value_and_grad(
lambda q, kv, q_token, kv_token, dropout_rng: grad_func(
jax_cross_attn, q, kv, q_token, kv_token, dropout_rng, **kwargs), (0, 1)))
primitive_out, (primitive_dq,
primitive_dkv) = jitted_primitive(self.q, self.kv, self.q_token,
self.kv_token, self.dropout_rng)
reference_out, (reference_dq, @staticmethod
reference_dkv) = jitted_reference(self.q, self.kv, self.q_token, def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
self.kv_token, self.dropout_rng) dtype, is_training, qkv_layout):
"""
# Dropout can't get the bitmatch result, skip the elementwise comparison Test forward with parameterized configs
if dropout_probability > 0.: """
return runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
dropout_prob, dtype, is_training, qkv_layout)
runner.test_forward()
np.testing.assert_allclose(jnp.asarray(primitive_out, np.float32), @staticmethod
jnp.asarray(reference_out, np.float32), def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
rtol=1e-4, dtype, is_training, qkv_layout):
atol=1e-5) """
Test backward with parameterized configs
valid_primitive_dq, invalid_primitive_dq = jnp.split(primitive_dq, (self.q_valid_len,), """
axis=1) runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
valid_reference_dq, invalid_reference_dq = jnp.split(reference_dq, (self.q_valid_len,), dropout_prob, dtype, is_training, qkv_layout)
axis=1) runner.test_backward()
valid_primitive_dkv, invalid_primitive_dkv = jnp.split(primitive_dkv, (self.kv_valid_len,),
axis=1)
valid_reference_dkv, invalid_reference_dkv = jnp.split(reference_dkv, (self.kv_valid_len,),
axis=1)
# dQ
np.testing.assert_allclose(jnp.asarray(valid_primitive_dq, np.float32),
jnp.asarray(valid_reference_dq, np.float32),
rtol=1e-4,
atol=1e-5)
# dK
np.testing.assert_allclose(jnp.asarray(valid_primitive_dkv[:, :, 0], np.float32),
jnp.asarray(valid_reference_dkv[:, :, 0], np.float32),
rtol=1e-4,
atol=1e-5)
# dV
np.testing.assert_allclose(jnp.asarray(valid_primitive_dkv[:, :, 1], np.float32),
jnp.asarray(valid_reference_dkv[:, :, 1], np.float32),
rtol=1e-4,
atol=1e-5)
assert jnp.allclose(invalid_primitive_dq, invalid_reference_dq)
assert jnp.allclose(invalid_primitive_dkv, invalid_reference_dkv)
# Padded part should be 0s
assert jnp.allclose(invalid_primitive_dq, jnp.zeros_like(invalid_primitive_dq))
assert jnp.allclose(invalid_primitive_dkv, jnp.zeros_like(invalid_primitive_dkv))
...@@ -9,13 +9,14 @@ import jax ...@@ -9,13 +9,14 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
import pytest import pytest
from transformer_engine.common.recipe import Format
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
from utils import assert_allclose from utils import assert_allclose
from utils import DecoderLayer as RefDecoderLayer from utils import DecoderLayer as RefDecoderLayer
from utils import EncoderLayer as RefEncoderLayer from utils import EncoderLayer as RefEncoderLayer
from transformer_engine.common.recipe import Format
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_fp8_available()
...@@ -85,8 +86,13 @@ _KEY_OF_LAYERNORM_TYPE = 'layernorm_type' ...@@ -85,8 +86,13 @@ _KEY_OF_LAYERNORM_TYPE = 'layernorm_type'
_KEY_OF_ZERO_CENTERED_GAMMA = 'zero_centered_gamma' _KEY_OF_ZERO_CENTERED_GAMMA = 'zero_centered_gamma'
_KEY_OF_TRANSPOSE_BS = 'transpose_batch_sequence' _KEY_OF_TRANSPOSE_BS = 'transpose_batch_sequence'
_KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits" _KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits"
_KEY_OF_NUM_HEADS = 'num_attention_heads'
_KEY_OF_NUM_GQA_GROUPS = 'num_gqa_groups'
BASE_ATTRS = {_KEY_OF_TRANSPOSE_BS: True} BASE_ATTRS = {
_KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_NUM_HEADS: 8,
}
ATTRS = [{ ATTRS = [{
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
...@@ -129,6 +135,9 @@ ATTRS = [{ ...@@ -129,6 +135,9 @@ ATTRS = [{
_KEY_OF_DROPOUT_RATE: 0.0, _KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_MLP_ACTIVATIONS: (('gelu', 'linear')), _KEY_OF_MLP_ACTIVATIONS: (('gelu', 'linear')),
_KEY_OF_FUSE_MLP_WI: True _KEY_OF_FUSE_MLP_WI: True
}, {
_KEY_OF_NUM_HEADS: 8,
_KEY_OF_NUM_GQA_GROUPS: 4
}] }]
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS] ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
...@@ -137,21 +146,13 @@ ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS] ...@@ -137,21 +146,13 @@ ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
class TestEncoderLayer: class TestEncoderLayer:
@staticmethod @staticmethod
def sync_params(ref, target, attrs): def sync_params(ref, target):
fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)
unfreeze_target = flax.core.unfreeze(target) unfreeze_target = flax.core.unfreeze(target)
if fuse_qkv: unfreeze_attn_scope = unfreeze_target['attention']
unfreeze_target['attention']['qkv']['kernel'] = \ ref_attn_scope = ref['attention']
jnp.reshape(ref['attention']['qkv']['kernel'], for key in ref_attn_scope.keys():
unfreeze_target['attention']['qkv']['kernel'].shape) unfreeze_attn_scope[key]['kernel'] = \
else: ref_attn_scope[key]['kernel'].reshape(unfreeze_attn_scope[key]['kernel'].shape)
unfreeze_target['attention']['query']['kernel'] = \
ref['attention']['query']['kernel']
unfreeze_target['attention']['key']['kernel'] = \
ref['attention']['key']['kernel']
unfreeze_target['attention']['value']['kernel'] = \
ref['attention']['value']['kernel']
unfreeze_target['mlp']['wi_kernel'] = \ unfreeze_target['mlp']['wi_kernel'] = \
jnp.reshape(ref['mlp']['wi']['kernel'], unfreeze_target['mlp']['wi_kernel'].shape) jnp.reshape(ref['mlp']['wi']['kernel'], unfreeze_target['mlp']['wi_kernel'].shape)
unfreeze_target['mlp']['wo_kernel'] = \ unfreeze_target['mlp']['wo_kernel'] = \
...@@ -196,7 +197,7 @@ class TestEncoderLayer: ...@@ -196,7 +197,7 @@ class TestEncoderLayer:
test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs, test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
test_masks) test_masks)
ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params, attrs) ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params)
ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng) ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng)
test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng) test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng)
...@@ -242,7 +243,7 @@ class TestEncoderLayer: ...@@ -242,7 +243,7 @@ class TestEncoderLayer:
test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs, test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
test_masks) test_masks)
ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params, attrs) ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params)
if FP8Helper.is_fp8_enabled(): if FP8Helper.is_fp8_enabled():
for _ in range(4): for _ in range(4):
...@@ -266,7 +267,10 @@ class TestEncoderLayer: ...@@ -266,7 +267,10 @@ class TestEncoderLayer:
assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad
def reorganize_test_wgrad(test_wgrad, attrs): def reorganize_test_wgrad(test_wgrad, attrs):
fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True) num_heads = attrs.get(_KEY_OF_NUM_HEADS)
num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads)
fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True) and \
num_heads == num_gqa_groups
attn_name = 'attention' attn_name = 'attention'
unfreeze_test_wgrad = flax.core.unfreeze(test_wgrad) unfreeze_test_wgrad = flax.core.unfreeze(test_wgrad)
...@@ -280,10 +284,12 @@ class TestEncoderLayer: ...@@ -280,10 +284,12 @@ class TestEncoderLayer:
unfreeze_test_wgrad['pre_attention_layer_norm']['ln_bias'] = \ unfreeze_test_wgrad['pre_attention_layer_norm']['ln_bias'] = \
unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias'] unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias'] del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
if fuse_qkv:
unfreeze_test_wgrad[attn_name]['qkv']['kernel'] = \ for key in unfreeze_test_wgrad[attn_name].keys():
jnp.reshape(unfreeze_test_wgrad[attn_name]['qkv']['kernel'], unfreeze_test_wgrad[attn_name][key]['kernel'] = \
(unfreeze_test_wgrad[attn_name]['qkv']['kernel'].shape[0], -1)) jnp.reshape(unfreeze_test_wgrad[attn_name][key]['kernel'],
(unfreeze_test_wgrad[attn_name][key]['kernel'].shape[0], -1))
unfreeze_test_wgrad['pre_mlp_layer_norm'] = {} unfreeze_test_wgrad['pre_mlp_layer_norm'] = {}
unfreeze_test_wgrad['pre_mlp_layer_norm']['scale'] = \ unfreeze_test_wgrad['pre_mlp_layer_norm']['scale'] = \
unfreeze_test_wgrad['mlp']['scale'] unfreeze_test_wgrad['mlp']['scale']
...@@ -348,26 +354,14 @@ class TestEncoderLayer: ...@@ -348,26 +354,14 @@ class TestEncoderLayer:
class TestDecoderLayer: class TestDecoderLayer:
@staticmethod @staticmethod
def sync_params(ref, target, attrs): def sync_params(ref, target):
fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)
unfreeze_target = flax.core.unfreeze(target) unfreeze_target = flax.core.unfreeze(target)
if fuse_qkv: for scope in ['self_attention', 'encoder_decoder_attention']:
unfreeze_target['self_attention']['qkv']['kernel'] = \ unfreeze_scope = unfreeze_target[scope]
jnp.reshape(ref['self_attention']['qkv']['kernel'], ref_scope = ref[scope]
unfreeze_target['self_attention']['qkv']['kernel'].shape) for key in unfreeze_scope.keys():
unfreeze_target['encoder_decoder_attention']['kv']['kernel'] = \ unfreeze_scope[key]['kernel'] = \
jnp.reshape(ref['encoder_decoder_attention']['kv']['kernel'], ref_scope[key]['kernel'].reshape(unfreeze_scope[key]['kernel'].shape)
unfreeze_target['encoder_decoder_attention']['kv']['kernel'].shape)
else:
unfreeze_target['self_attention']['query']['kernel'] = \
ref['self_attention']['query']['kernel']
unfreeze_target['self_attention']['key']['kernel'] = \
ref['self_attention']['key']['kernel']
unfreeze_target['self_attention']['value']['kernel'] = \
ref['self_attention']['value']['kernel']
unfreeze_target['encoder_decoder_attention']['query']['kernel'] = \
ref['encoder_decoder_attention']['query']['kernel']
unfreeze_target['mlp']['wi_kernel'] = \ unfreeze_target['mlp']['wi_kernel'] = \
jnp.reshape(ref['mlp']['wi']['kernel'], unfreeze_target['mlp']['wi_kernel'].shape) jnp.reshape(ref['mlp']['wi']['kernel'], unfreeze_target['mlp']['wi_kernel'].shape)
unfreeze_target['mlp']['wo_kernel'] = \ unfreeze_target['mlp']['wo_kernel'] = \
...@@ -412,7 +406,7 @@ class TestDecoderLayer: ...@@ -412,7 +406,7 @@ class TestDecoderLayer:
test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs, test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
test_masks) test_masks)
ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params, attrs) ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params)
ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng) ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng)
test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng) test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng)
...@@ -459,7 +453,7 @@ class TestDecoderLayer: ...@@ -459,7 +453,7 @@ class TestDecoderLayer:
test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs, test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
test_masks) test_masks)
ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params, attrs) ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params)
if FP8Helper.is_fp8_enabled(): if FP8Helper.is_fp8_enabled():
for _ in range(4): for _ in range(4):
...@@ -483,11 +477,14 @@ class TestDecoderLayer: ...@@ -483,11 +477,14 @@ class TestDecoderLayer:
assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad
def reorganize_test_wgrad(test_wgrad, attrs): def reorganize_test_wgrad(test_wgrad, attrs):
fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True) num_heads = attrs.get(_KEY_OF_NUM_HEADS)
attn_name = 'self_attention' num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads)
fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True) and \
num_heads == num_gqa_groups
unfreeze_test_wgrad = flax.core.unfreeze(test_wgrad) unfreeze_test_wgrad = flax.core.unfreeze(test_wgrad)
if "output_layernorm" not in attrs: if "output_layernorm" not in attrs:
attn_name = 'self_attention'
unfreeze_test_wgrad['pre_self_attention_layer_norm'] = {} unfreeze_test_wgrad['pre_self_attention_layer_norm'] = {}
pre_attn_layer_key = 'qkv' if fuse_qkv else 'query' pre_attn_layer_key = 'qkv' if fuse_qkv else 'query'
unfreeze_test_wgrad['pre_self_attention_layer_norm']['scale'] = \ unfreeze_test_wgrad['pre_self_attention_layer_norm']['scale'] = \
...@@ -498,14 +495,11 @@ class TestDecoderLayer: ...@@ -498,14 +495,11 @@ class TestDecoderLayer:
unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias'] unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias'] del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
if fuse_qkv: for scope in ['self_attention', 'encoder_decoder_attention']:
unfreeze_test_wgrad[attn_name]['qkv']['kernel'] = \ for key in unfreeze_test_wgrad[scope].keys():
jnp.reshape(unfreeze_test_wgrad[attn_name]['qkv']['kernel'], unfreeze_test_wgrad[scope][key]['kernel'] = \
(unfreeze_test_wgrad[attn_name]['qkv']['kernel'].shape[0], -1)) jnp.reshape(unfreeze_test_wgrad[scope][key]['kernel'],
attn_name = 'encoder_decoder_attention' (unfreeze_test_wgrad[scope][key]['kernel'].shape[0], -1))
unfreeze_test_wgrad[attn_name]['kv']['kernel'] = \
jnp.reshape(unfreeze_test_wgrad[attn_name]['kv']['kernel'],
(unfreeze_test_wgrad[attn_name]['kv']['kernel'].shape[0], -1))
unfreeze_test_wgrad['pre_cross_attention_layer_norm'] = {} unfreeze_test_wgrad['pre_cross_attention_layer_norm'] = {}
unfreeze_test_wgrad['pre_cross_attention_layer_norm']['scale'] = \ unfreeze_test_wgrad['pre_cross_attention_layer_norm']['scale'] = \
......
...@@ -12,6 +12,8 @@ from praxis import pax_fiddle ...@@ -12,6 +12,8 @@ from praxis import pax_fiddle
from praxis.base_layer import WeightInit, DEFAULT_INIT_MUTABLE_LIST from praxis.base_layer import WeightInit, DEFAULT_INIT_MUTABLE_LIST
import pytest import pytest
from utils import assert_allclose
from transformer_engine.common.recipe import DelayedScaling, Format from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast, update_fp8_metas, update_collections from transformer_engine.jax import fp8_autocast, update_fp8_metas, update_collections
from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral
...@@ -23,12 +25,12 @@ from transformer_engine.jax.flax import TransformerLayer as flax_TransformerLaye ...@@ -23,12 +25,12 @@ from transformer_engine.jax.flax import TransformerLayer as flax_TransformerLaye
from transformer_engine.jax.flax.module import Softmax from transformer_engine.jax.flax.module import Softmax
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
from transformer_engine.jax.praxis import LayerNorm from transformer_engine.jax.praxis import LayerNorm
from transformer_engine.jax.praxis import FusedSoftmax, LayerNorm from transformer_engine.jax.praxis import FusedSoftmax
from transformer_engine.jax.praxis import LayerNormLinear, LayerNormMLP, Linear from transformer_engine.jax.praxis import LayerNormLinear, LayerNormMLP, Linear
from transformer_engine.jax.praxis import MultiHeadAttention, RelativePositionBiases from transformer_engine.jax.praxis import MultiHeadAttention, RelativePositionBiases
from transformer_engine.jax.praxis import TransformerEngineBaseLayer, TransformerLayer, TransformerLayerType from transformer_engine.jax.praxis import TransformerEngineBaseLayer
from transformer_engine.jax.praxis import TransformerLayer, TransformerLayerType
from transformer_engine.jax.softmax import SoftmaxType from transformer_engine.jax.softmax import SoftmaxType
from utils import assert_allclose
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_fp8_available()
...@@ -674,6 +676,8 @@ class MultiHeadAttnAttr: ...@@ -674,6 +676,8 @@ class MultiHeadAttnAttr:
LN_TYPE = 'layernorm_type' LN_TYPE = 'layernorm_type'
ATTN_MASK_TYPE = 'attn_mask_type' ATTN_MASK_TYPE = 'attn_mask_type'
ZERO_CEN = 'zero_centered_gamma' ZERO_CEN = 'zero_centered_gamma'
NUM_ATTN_HEADS = 'num_attention_heads'
NUM_GQA_GROUPS = 'num_gqa_groups'
ATTRS = [{ ATTRS = [{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: 'layernorm',
...@@ -704,6 +708,13 @@ class MultiHeadAttnAttr: ...@@ -704,6 +708,13 @@ class MultiHeadAttnAttr:
LN_TYPE: 'rmsnorm', LN_TYPE: 'rmsnorm',
ZERO_CEN: False, ZERO_CEN: False,
ATTN_MASK_TYPE: 'causal' ATTN_MASK_TYPE: 'causal'
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
NUM_ATTN_HEADS: 8,
NUM_GQA_GROUPS: 4,
ATTN_MASK_TYPE: 'causal'
}] }]
......
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
"""Utility for the TE layer tests"""
import functools import functools
import math import math
...@@ -28,6 +29,9 @@ Initializer = Callable[[PRNGKey, Shape, DType], Array] ...@@ -28,6 +29,9 @@ Initializer = Callable[[PRNGKey, Shape, DType], Array]
def is_devices_enough(required): def is_devices_enough(required):
"""
Check if the available GPUs is enough
"""
return len(jax.devices()) >= required return len(jax.devices()) >= required
...@@ -121,9 +125,9 @@ def dot_product_attention(query: Array, ...@@ -121,9 +125,9 @@ def dot_product_attention(query: Array,
query: queries for calculating attention with shape of `[batch, q_length, query: queries for calculating attention with shape of `[batch, q_length,
num_heads, qk_depth_per_head]`. num_heads, qk_depth_per_head]`.
key: keys for calculating attention with shape of `[batch, kv_length, key: keys for calculating attention with shape of `[batch, kv_length,
num_heads, qk_depth_per_head]`. num_gqa_groups, qk_depth_per_head]`.
value: values to be used in attention with shape of `[batch, kv_length, value: values to be used in attention with shape of `[batch, kv_length,
num_heads, v_depth_per_head]`. num_gqa_groups, v_depth_per_head]`.
bias: bias for the attention weights. This should be broadcastable to the bias: bias for the attention weights. This should be broadcastable to the
shape `[batch, num_heads, q_length, kv_length]` This can be used for shape `[batch, num_heads, q_length, kv_length]` This can be used for
incorporating causal masks, padding masks, proximity bias, etc. incorporating causal masks, padding masks, proximity bias, etc.
...@@ -141,21 +145,31 @@ def dot_product_attention(query: Array, ...@@ -141,21 +145,31 @@ def dot_product_attention(query: Array,
batch_dim = 1 if transpose_batch_sequence else 0 batch_dim = 1 if transpose_batch_sequence else 0
assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], ( assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], (
'q, k, v batch dims must match.') 'q, k, v batch dims must match.')
assert query.shape[-2] == key.shape[-2] == value.shape[-2], ('q, k, v num_heads must match.')
sequence_dim = 0 if transpose_batch_sequence else 1 sequence_dim = 0 if transpose_batch_sequence else 1
assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.' assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.'
assert query.shape[-1] == key.shape[-1], 'q, k depths must match.' assert key.shape[-2] == value.shape[-2], 'k, v num_heads must match.'
assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.'
# Casting logits and softmax computation for float32 for model stability. # Casting logits and softmax computation for float32 for model stability.
if float32_logits: if float32_logits:
query = query.astype(jnp.float32) query = query.astype(jnp.float32)
key = key.astype(jnp.float32) key = key.astype(jnp.float32)
# `attn_weights`: [batch, num_heads, q_length, kv_length] # `attn_weights`: [batch, num_heads, groups, q_length, kv_length]
h_q, h_kv = query.shape[-2], key.shape[-2]
assert (h_q % h_kv == 0) and (h_q >= h_kv)
group_size = h_q // h_kv
grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))
if transpose_batch_sequence: if transpose_batch_sequence:
attn_weights = jnp.einsum('qbhd,kbhd->bhqk', query, key) attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key)
else: else:
attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key) attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)
# reshape back to normal DPA shape for bias/softmax/dropout
b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
attn_weights_without_groups_shape = (b, h * g, q, k)
attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)
# Apply attention bias: masking, dropout, proximity bias, etc. # Apply attention bias: masking, dropout, proximity bias, etc.
if bias is not None: if bias is not None:
...@@ -174,11 +188,13 @@ def dot_product_attention(query: Array, ...@@ -174,11 +188,13 @@ def dot_product_attention(query: Array,
multiplier = (keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype)) multiplier = (keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype))
attn_weights = attn_weights * multiplier attn_weights = attn_weights * multiplier
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
# Take the linear combination of `value`. # Take the linear combination of `value`.
if transpose_batch_sequence: if transpose_batch_sequence:
return jnp.einsum('bhqk,kbhd->qbhd', attn_weights, value) return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape)
return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value) return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape)
class DenseGeneral(nn.Module): class DenseGeneral(nn.Module):
...@@ -235,7 +251,8 @@ class DenseGeneral(nn.Module): ...@@ -235,7 +251,8 @@ class DenseGeneral(nn.Module):
if self.use_bias: if self.use_bias:
bias = nn_partitioning.param_with_axes('bias', bias = nn_partitioning.param_with_axes('bias',
self.bias_init, (self.features,), self.bias_init,
self.features,
self.dtype, self.dtype,
axes=self.bias_axes) axes=self.bias_axes)
else: else:
...@@ -332,6 +349,7 @@ class MultiHeadAttention(nn.Module): ...@@ -332,6 +349,7 @@ class MultiHeadAttention(nn.Module):
Attributes: Attributes:
num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
should be divisible by the number of heads. should be divisible by the number of heads.
num_gqa_groups: number of kv attention heads
head_dim: dimension of each head. head_dim: dimension of each head.
dtype: the dtype of the computation. dtype: the dtype of the computation.
dropout_rate: dropout rate dropout_rate: dropout rate
...@@ -340,9 +358,10 @@ class MultiHeadAttention(nn.Module): ...@@ -340,9 +358,10 @@ class MultiHeadAttention(nn.Module):
numerical issues with bfloat16. numerical issues with bfloat16.
""" """
num_heads: int num_heads: int = 8
head_dim: int num_gqa_groups: int | None = None
transpose_batch_sequence: bool head_dim: int = 64
transpose_batch_sequence: bool = True
dtype: DType = jnp.float32 dtype: DType = jnp.float32
dropout_rate: float = 0. dropout_rate: float = 0.
kernel_init: Initializer = None kernel_init: Initializer = None
...@@ -354,6 +373,8 @@ class MultiHeadAttention(nn.Module): ...@@ -354,6 +373,8 @@ class MultiHeadAttention(nn.Module):
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal') self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__() super().__post_init__()
@nn.compact @nn.compact
...@@ -393,18 +414,24 @@ class MultiHeadAttention(nn.Module): ...@@ -393,18 +414,24 @@ class MultiHeadAttention(nn.Module):
Returns: Returns:
output of shape `[batch, length, q_features]`. output of shape `[batch, length, q_features]`.
""" """
projection = functools.partial(DenseGeneral, q_projection = functools.partial(DenseGeneral,
axis=-1, axis=-1,
features=self.num_heads * self.head_dim, features=self.num_heads * self.head_dim,
kernel_axes=('embed', 'joined_kv'), kernel_axes=('embed', 'joined_kv'),
dtype=self.dtype) dtype=self.dtype)
kv_projection = functools.partial(DenseGeneral,
axis=-1,
features=self.num_gqa_groups * self.head_dim,
kernel_axes=('embed', 'joined_kv'),
dtype=self.dtype)
# NOTE: T5 does not explicitly rescale the attention logits by # NOTE: T5 does not explicitly rescale the attention logits by
# 1/sqrt(depth_kq)! This is folded into the initializers of the # 1/sqrt(depth_kq)! This is folded into the initializers of the
# linear transformations, which is equivalent under Adafactor # linear transformations, which is equivalent under Adafactor
depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
query_init = lambda *args: self.kernel_init(*args) / ( # pylint: disable=unnecessary-lambda-assignment query_init = lambda *args: self.kernel_init(*args) / (depth_scaling
depth_scaling if self.scaled_query_init else 1.0) if self.scaled_query_init else 1.0)
# Project inputs_q to multi-headed q/k/v # Project inputs_q to multi-headed q/k/v
# dimensions are then [batch, length, num_heads, head_dim] # dimensions are then [batch, length, num_heads, head_dim]
...@@ -417,13 +444,17 @@ class MultiHeadAttention(nn.Module): ...@@ -417,13 +444,17 @@ class MultiHeadAttention(nn.Module):
v_shape = (shape[0], shape[1] // 3) v_shape = (shape[0], shape[1] // 3)
q_kernel = query_init(key, q_shape, dtype) q_kernel = query_init(key, q_shape, dtype)
k_kernel = self.kernel_init(key, k_shape, dtype) # pylint: disable=too-many-function-args k_kernel = self.kernel_init(key, k_shape, dtype)
v_kernel = self.kernel_init(key, v_shape, dtype) # pylint: disable=too-many-function-args v_kernel = self.kernel_init(key, v_shape, dtype)
return jnp.concatenate([q_kernel, k_kernel, v_kernel], axis=-1, dtype=dtype) return jnp.concatenate([q_kernel, k_kernel, v_kernel], axis=-1, dtype=dtype)
is_self_attn = (inputs_q is inputs_kv)
is_gqa = (self.num_heads != self.num_gqa_groups)
is_qkvpack = (is_self_attn and not is_gqa)
if self.fuse_qkv: if self.fuse_qkv:
if inputs_q is inputs_kv: if is_qkvpack:
qkv_proj = DenseGeneral(axis=-1, qkv_proj = DenseGeneral(axis=-1,
features=self.num_heads * self.head_dim * 3, features=self.num_heads * self.head_dim * 3,
kernel_axes=('embed', 'joined_kv'), kernel_axes=('embed', 'joined_kv'),
...@@ -436,24 +467,24 @@ class MultiHeadAttention(nn.Module): ...@@ -436,24 +467,24 @@ class MultiHeadAttention(nn.Module):
if self.scale_attn_logits: if self.scale_attn_logits:
query = query / depth_scaling query = query / depth_scaling
else: else:
query = projection(kernel_init=query_init, name='query')( \ query = q_projection(kernel_init=query_init, name='query')( \
(inputs_q / depth_scaling) if self.scale_attn_logits else inputs_q) (inputs_q / depth_scaling) if self.scale_attn_logits else inputs_q)
kv_proj = DenseGeneral(axis=-1, kv_proj = DenseGeneral(axis=-1,
features=self.num_heads * self.head_dim * 2, features=self.num_gqa_groups * self.head_dim * 2,
kernel_axes=('embed', 'joined_kv'), kernel_axes=('embed', 'joined_kv'),
kernel_init=self.kernel_init, kernel_init=self.kernel_init,
name='kv', name='kv',
dtype=self.dtype)(inputs_kv) dtype=self.dtype)(inputs_kv)
key, value = jnp.split(kv_proj, [self.num_heads * self.head_dim], axis=-1) key, value = jnp.split(kv_proj, [self.num_gqa_groups * self.head_dim], axis=-1)
else: else:
query = projection(kernel_init=query_init, name='query')( \ query = q_projection(kernel_init=query_init, name='query')( \
(inputs_q / depth_scaling) if self.scale_attn_logits else inputs_q) (inputs_q / depth_scaling) if self.scale_attn_logits else inputs_q)
key = projection(kernel_init=self.kernel_init, name='key')(inputs_kv) key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
value = projection(kernel_init=self.kernel_init, name='value')(inputs_kv) value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv)
query = query.reshape((query.shape[0], query.shape[1], self.num_heads, self.head_dim)) query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
key = key.reshape((key.shape[0], key.shape[1], self.num_heads, self.head_dim)) key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
value = value.reshape((value.shape[0], value.shape[1], self.num_heads, self.head_dim)) value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
if self.transpose_batch_sequence: if self.transpose_batch_sequence:
query = nn_partitioning.with_sharding_constraint(query, query = nn_partitioning.with_sharding_constraint(query,
...@@ -476,7 +507,7 @@ class MultiHeadAttention(nn.Module): ...@@ -476,7 +507,7 @@ class MultiHeadAttention(nn.Module):
# fusion optimization. This also enables the "scatter via one-hot # fusion optimization. This also enables the "scatter via one-hot
# broadcast" trick, which means we do a one-hot broadcast instead of a # broadcast" trick, which means we do a one-hot broadcast instead of a
# scatter/gather operations, resulting in a 3-4x speedup in practice. # scatter/gather operations, resulting in a 3-4x speedup in practice.
swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3]) # pylint: disable=unnecessary-lambda-assignment swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3])
cached_key = self.variable('cache', 'cached_key', jnp.zeros, swap_dims(key.shape), cached_key = self.variable('cache', 'cached_key', jnp.zeros, swap_dims(key.shape),
key.dtype) key.dtype)
cached_value = self.variable('cache', 'cached_value', jnp.zeros, swap_dims(value.shape), cached_value = self.variable('cache', 'cached_value', jnp.zeros, swap_dims(value.shape),
...@@ -755,7 +786,8 @@ class RelativePositionBiases(nn.Module): ...@@ -755,7 +786,8 @@ class RelativePositionBiases(nn.Module):
class EncoderLayer(nn.Module): class EncoderLayer(nn.Module):
"""Transformer encoder layer.""" """Transformer encoder layer."""
relative_embedding: nn.Module = None relative_embedding: nn.Module = None
num_heads: int = 8 num_attention_heads: int = 8
num_gqa_groups: int | None = None
head_dim: int = 64 head_dim: int = 64
dropout_rate: float = 0.1 dropout_rate: float = 0.1
transpose_batch_sequence: bool = True transpose_batch_sequence: bool = True
...@@ -773,6 +805,11 @@ class EncoderLayer(nn.Module): ...@@ -773,6 +805,11 @@ class EncoderLayer(nn.Module):
fuse_qkv_params: bool = True fuse_qkv_params: bool = True
fuse_mlp_wi: bool = False fuse_mlp_wi: bool = False
def __post_init__(self):
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
@nn.compact @nn.compact
def __call__(self, inputs, encoder_mask=None, deterministic=False): def __call__(self, inputs, encoder_mask=None, deterministic=False):
# Relative position embedding as attention biases. # Relative position embedding as attention biases.
...@@ -782,7 +819,7 @@ class EncoderLayer(nn.Module): ...@@ -782,7 +819,7 @@ class EncoderLayer(nn.Module):
if self.relative_embedding is None: if self.relative_embedding is None:
rel_emb = RelativePositionBiases(num_buckets=32, rel_emb = RelativePositionBiases(num_buckets=32,
max_distance=128, max_distance=128,
num_heads=self.num_heads, num_heads=self.num_attention_heads,
dtype=self.dtype, dtype=self.dtype,
embedding_init=nn.initializers.variance_scaling( embedding_init=nn.initializers.variance_scaling(
1.0, 'fan_avg', 'uniform'), 1.0, 'fan_avg', 'uniform'),
...@@ -807,7 +844,8 @@ class EncoderLayer(nn.Module): ...@@ -807,7 +844,8 @@ class EncoderLayer(nn.Module):
x = inputs x = inputs
# [batch, length, emb_dim] -> [batch, length, emb_dim] # [batch, length, emb_dim] -> [batch, length, emb_dim]
x = MultiHeadAttention(num_heads=self.num_heads, x = MultiHeadAttention(num_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
dtype=self.dtype, dtype=self.dtype,
head_dim=self.head_dim, head_dim=self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
...@@ -868,7 +906,8 @@ class EncoderLayer(nn.Module): ...@@ -868,7 +906,8 @@ class EncoderLayer(nn.Module):
class DecoderLayer(nn.Module): class DecoderLayer(nn.Module):
"""Transformer decoder layer that attends to the encoder.""" """Transformer decoder layer that attends to the encoder."""
relative_embedding: nn.Module = None relative_embedding: nn.Module = None
num_heads: int = 8 num_attention_heads: int = 8
num_gqa_groups: int | None = None
head_dim: int = 64 head_dim: int = 64
dropout_rate: float = 0.1 dropout_rate: float = 0.1
transpose_batch_sequence: bool = True transpose_batch_sequence: bool = True
...@@ -886,6 +925,11 @@ class DecoderLayer(nn.Module): ...@@ -886,6 +925,11 @@ class DecoderLayer(nn.Module):
fuse_qkv_params: bool = True fuse_qkv_params: bool = True
fuse_mlp_wi: bool = False fuse_mlp_wi: bool = False
def __post_init__(self):
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
@nn.compact @nn.compact
def __call__(self, def __call__(self,
inputs, inputs,
...@@ -903,7 +947,7 @@ class DecoderLayer(nn.Module): ...@@ -903,7 +947,7 @@ class DecoderLayer(nn.Module):
if self.relative_embedding is None: if self.relative_embedding is None:
rel_emb = RelativePositionBiases(num_buckets=32, rel_emb = RelativePositionBiases(num_buckets=32,
max_distance=128, max_distance=128,
num_heads=self.num_heads, num_heads=self.num_attention_heads,
dtype=self.dtype, dtype=self.dtype,
embedding_init=nn.initializers.variance_scaling( embedding_init=nn.initializers.variance_scaling(
1.0, 'fan_avg', 'uniform'), 1.0, 'fan_avg', 'uniform'),
...@@ -928,7 +972,8 @@ class DecoderLayer(nn.Module): ...@@ -928,7 +972,8 @@ class DecoderLayer(nn.Module):
x = inputs x = inputs
# Self-attention block # Self-attention block
x = MultiHeadAttention(num_heads=self.num_heads, x = MultiHeadAttention(num_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
dtype=self.dtype, dtype=self.dtype,
head_dim=self.head_dim, head_dim=self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
...@@ -960,7 +1005,8 @@ class DecoderLayer(nn.Module): ...@@ -960,7 +1005,8 @@ class DecoderLayer(nn.Module):
if self.apply_residual_connection_post_layernorm: if self.apply_residual_connection_post_layernorm:
residual = y residual = y
y = MultiHeadAttention(num_heads=self.num_heads, y = MultiHeadAttention(num_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
dtype=self.dtype, dtype=self.dtype,
head_dim=self.head_dim, head_dim=self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
...@@ -1012,6 +1058,9 @@ class DecoderLayer(nn.Module): ...@@ -1012,6 +1058,9 @@ class DecoderLayer(nn.Module):
def make_causal_mask(batch, seqlen, dtype=jnp.uint8): def make_causal_mask(batch, seqlen, dtype=jnp.uint8):
"""
Generate causal mask
"""
shape = (batch, seqlen) shape = (batch, seqlen)
idxs = jnp.broadcast_to(jnp.arange(shape[-1], dtype=jnp.int32), shape) idxs = jnp.broadcast_to(jnp.arange(shape[-1], dtype=jnp.int32), shape)
...@@ -1022,6 +1071,9 @@ def make_causal_mask(batch, seqlen, dtype=jnp.uint8): ...@@ -1022,6 +1071,9 @@ def make_causal_mask(batch, seqlen, dtype=jnp.uint8):
def make_self_mask(batch, seqlen, dtype=jnp.uint8): def make_self_mask(batch, seqlen, dtype=jnp.uint8):
"""
Generate attention mask
"""
shape = (batch, seqlen) shape = (batch, seqlen)
mask = jnp.ones((*shape, shape[-1])) mask = jnp.ones((*shape, shape[-1]))
mask = jnp.expand_dims(mask, axis=-3) mask = jnp.expand_dims(mask, axis=-3)
...@@ -1057,7 +1109,7 @@ def assert_allclose( ...@@ -1057,7 +1109,7 @@ def assert_allclose(
dtype = actual.dtype dtype = actual.dtype
# Determine tolerances # Determine tolerances
tols = dict() tols = {}
if rtol is None or atol is None: if rtol is None or atol is None:
tols = dtype_tols(dtype) tols = dtype_tols(dtype)
if rtol is not None: if rtol is not None:
......
...@@ -573,14 +573,14 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( ...@@ -573,14 +573,14 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine; using namespace transformer_engine;
const DType QKV_type = input_QKV->data.dtype; const auto QKV_type = input_QKV->data.dtype;
void *devPtrQKV = input_QKV->data.dptr; void *devPtrQKV = input_QKV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto stride = 0; size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = 2 * num_attn_heads * head_dim; stride = typeToSize(QKV_type) * num_attn_heads * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
stride = 2 * head_dim; stride = typeToSize(QKV_type) * head_dim;
} }
void *devPtrQ = static_cast<void *>(devPtrQKV); void *devPtrQ = static_cast<void *>(devPtrQKV);
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride); void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
...@@ -677,14 +677,15 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea ...@@ -677,14 +677,15 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine; using namespace transformer_engine;
const auto QKV_type = input_QKV->data.dtype;
void *devPtrQKV = input_QKV->data.dptr; void *devPtrQKV = input_QKV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto stride = 0; size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = 2 * num_attn_heads * head_dim; stride = typeToSize(QKV_type) * num_attn_heads * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
stride = 2 * head_dim; stride = typeToSize(QKV_type) * head_dim;
} }
void *devPtrQ = devPtrQKV; void *devPtrQ = devPtrQKV;
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride); void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
...@@ -712,7 +713,6 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea ...@@ -712,7 +713,6 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea
void* devPtrDropoutOffset = reinterpret_cast<void *>( void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1); reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
const auto qkv_type = input_QKV->data.dtype;
size_t workspace_size = 0; size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_bwd_impl(batch, num_attn_heads, num_attn_heads, fused_attn_arbitrary_seqlen_bwd_impl(batch, num_attn_heads, num_attn_heads,
...@@ -723,7 +723,7 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea ...@@ -723,7 +723,7 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea
devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlens, devPtrCuSeqlens, devPtrCuSeqlens, devPtrCuSeqlens,
get_cudnn_fe_dtype(qkv_type), workspace->data.dptr, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle); &workspace_size, stream, handle);
if (workspace_size > 0) { if (workspace_size > 0) {
...@@ -750,15 +750,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( ...@@ -750,15 +750,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine; using namespace transformer_engine;
const DType QKV_type = input_Q->data.dtype; const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr; void *devPtrQ = input_Q->data.dptr;
void *devPtrKV = input_KV->data.dptr; void *devPtrKV = input_KV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto stride = 0; size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = 2 * num_attn_heads * head_dim; stride = typeToSize(QKV_type) * num_gqa_groups * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = 2 * head_dim; stride = typeToSize(QKV_type) * head_dim;
} }
void *devPtrK = devPtrKV; void *devPtrK = devPtrKV;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride); void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
...@@ -860,15 +860,14 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( ...@@ -860,15 +860,14 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
using namespace transformer_engine; using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype; const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr; void *devPtrQ = input_Q->data.dptr;
void *devPtrKV = input_KV->data.dptr; void *devPtrKV = input_KV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto stride = 0; size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = 2 * num_attn_heads * head_dim; stride = typeToSize(QKV_type) * num_gqa_groups * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = 2 * head_dim; stride = typeToSize(QKV_type) * head_dim;
} }
void *devPtrK = devPtrKV; void *devPtrK = devPtrKV;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride); void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
...@@ -935,7 +934,7 @@ void fused_attn_arbitrary_seqlen_fwd( ...@@ -935,7 +934,7 @@ void fused_attn_arbitrary_seqlen_fwd(
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine; using namespace transformer_engine;
const DType QKV_type = input_Q->data.dtype; const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr; void *devPtrQ = input_Q->data.dptr;
void *devPtrK = input_K->data.dptr; void *devPtrK = input_K->data.dptr;
void *devPtrV = input_V->data.dptr; void *devPtrV = input_V->data.dptr;
......
...@@ -1651,13 +1651,10 @@ class FusedAttnHelper: ...@@ -1651,13 +1651,10 @@ class FusedAttnHelper:
def get_fused_attn_backend(self): def get_fused_attn_backend(self):
"""Get the fused attention kernel backend""" """Get the fused attention kernel backend"""
return transformer_engine_jax.get_fused_attn_backend(jax_dtype_to_te_dtype(self.q_type), return transformer_engine_jax.get_fused_attn_backend(
jax_dtype_to_te_dtype(self.kv_type), jax_dtype_to_te_dtype(self.q_type), jax_dtype_to_te_dtype(self.kv_type),
self.qkv_layout, self.attn_bias_type, self.qkv_layout, self.attn_bias_type, self.attn_mask_type, self.dropout_probability,
self.attn_mask_type, self.num_heads_q, self.num_heads_kv, self.max_seqlen_q, self.max_seqlen_kv,
self.dropout_probability,
self.num_heads_q, self.num_heads_kv,
self.max_seqlen_q, self.max_seqlen_kv,
self.head_dim) self.head_dim)
...@@ -1701,12 +1698,11 @@ class _FusedAttnRNGStateChecker: ...@@ -1701,12 +1698,11 @@ class _FusedAttnRNGStateChecker:
return seed return seed
def generate_cu_seqlen(mask): def generate_cu_seqlen(actual_seqlen):
""" """
Generating cumsum seqlen for a batch Generating cumsum seqlen for a batch
""" """
seqlen = jnp.sum(mask == 0, axis=(-1, -2), dtype=jnp.int32) cu_seqlen = jnp.cumsum(actual_seqlen)
cu_seqlen = jnp.cumsum(seqlen)
cu_seqlen = jnp.hstack((0, cu_seqlen)) cu_seqlen = jnp.hstack((0, cu_seqlen))
return cu_seqlen return cu_seqlen
...@@ -1722,13 +1718,13 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive): ...@@ -1722,13 +1718,13 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
outer_primitive = None outer_primitive = None
@staticmethod @staticmethod
def abstract(qkv_aval, bias_aval, mask_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, def abstract(qkv_aval, bias_aval, seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training): attn_mask_type, scaling_factor, dropout_probability, is_training):
""" """
Self fused attention fwd abstract Self fused attention fwd abstract
""" """
# outer_primitve is squeezed_mask, inner_primitive is cu_seqlen # outer_primitve is seqlen, inner_primitive is cu_seqlen
del mask_or_cu_seqlen_aval, scaling_factor, is_training del seqlen_or_cu_seqlen_aval, scaling_factor, is_training
qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype) qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
*batch_shape, max_seqlen, nqkv, num_head, head_dim = qkv_aval.shape *batch_shape, max_seqlen, nqkv, num_head, head_dim = qkv_aval.shape
assert nqkv == 3 assert nqkv == 3
...@@ -1781,19 +1777,20 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive): ...@@ -1781,19 +1777,20 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_fused_attn_descriptor( opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability, batch, num_head, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor,
attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training) dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
out = custom_caller(SelfFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) out = custom_caller(SelfFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
return out return out
@staticmethod @staticmethod
def impl(qkv, bias, squeezed_mask, seed, attn_bias_type, attn_mask_type, scaling_factor, def impl(qkv, bias, seqlen, seed, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training): dropout_probability, is_training):
assert SelfFusedAttnFwdPrimitive.inner_primitive is not None assert SelfFusedAttnFwdPrimitive.inner_primitive is not None
cu_seqlen = generate_cu_seqlen(squeezed_mask) cu_seqlen = generate_cu_seqlen(seqlen)
output, softmax_aux, rng_state = SelfFusedAttnFwdPrimitive.inner_primitive.bind( output, softmax_aux, rng_state = SelfFusedAttnFwdPrimitive.inner_primitive.bind(
qkv, qkv,
...@@ -1859,10 +1856,9 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive): ...@@ -1859,10 +1856,9 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
register_primitive(SelfFusedAttnFwdPrimitive) register_primitive(SelfFusedAttnFwdPrimitive)
def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, squeezed_mask: jnp.ndarray, def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, seqlen: jnp.ndarray, seed: jnp.ndarray,
seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type, attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float, scaling_factor: float, dropout_probability: float, is_training: bool):
dropout_probability: float, is_training: bool):
""" """
Wrapper for TE self fused attention fwd Wrapper for TE self fused attention fwd
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
...@@ -1875,7 +1871,7 @@ def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, squeezed_mask: jnp. ...@@ -1875,7 +1871,7 @@ def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, squeezed_mask: jnp.
bias = jnp.zeros(0, dtype=qkv.dtype) bias = jnp.zeros(0, dtype=qkv.dtype)
return SelfFusedAttnFwdPrimitive.outer_primitive.bind(qkv, return SelfFusedAttnFwdPrimitive.outer_primitive.bind(qkv,
bias, bias,
squeezed_mask, seqlen,
seed, seed,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
...@@ -1896,14 +1892,14 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive): ...@@ -1896,14 +1892,14 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
@staticmethod @staticmethod
def abstract(qkv_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval, doutput_aval, def abstract(qkv_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval, doutput_aval,
mask_or_cu_seqlen_aval, *, attn_bias_type, attn_mask_type, scaling_factor, seqlen_or_cu_seqlen_aval, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training): dropout_probability, is_training):
""" """
Self fused attention bwd abstract Self fused attention bwd abstract
""" """
del softmax_aux_aval, rng_state_aval del softmax_aux_aval, rng_state_aval
# outer_primitve is squeezed_mask, inner_primitive is cu_seqlen # outer_primitve is seqlen, inner_primitive is cu_seqlen
del mask_or_cu_seqlen_aval, attn_bias_type, attn_mask_type del seqlen_or_cu_seqlen_aval, attn_bias_type, attn_mask_type
del scaling_factor, dropout_probability, is_training del scaling_factor, dropout_probability, is_training
qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype) qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
...@@ -1934,19 +1930,20 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive): ...@@ -1934,19 +1930,20 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_fused_attn_descriptor( opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability, batch, num_head, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor,
attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training) dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
out = custom_caller(SelfFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) out = custom_caller(SelfFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
return out return out
@staticmethod @staticmethod
def impl(qkv, bias, softmax_aux, rng_state, output, doutput, squeezed_mask, attn_bias_type, def impl(qkv, bias, softmax_aux, rng_state, output, doutput, seqlen, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training): attn_mask_type, scaling_factor, dropout_probability, is_training):
assert SelfFusedAttnBwdPrimitive.inner_primitive is not None assert SelfFusedAttnBwdPrimitive.inner_primitive is not None
cu_seqlen = generate_cu_seqlen(squeezed_mask) cu_seqlen = generate_cu_seqlen(seqlen)
dqkv, dbias = SelfFusedAttnBwdPrimitive.inner_primitive.bind( dqkv, dbias = SelfFusedAttnBwdPrimitive.inner_primitive.bind(
qkv, qkv,
...@@ -2029,7 +2026,7 @@ register_primitive(SelfFusedAttnBwdPrimitive) ...@@ -2029,7 +2026,7 @@ register_primitive(SelfFusedAttnBwdPrimitive)
def self_fused_attn_bwd(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: jnp.ndarray, def self_fused_attn_bwd(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: jnp.ndarray,
rng_state: jnp.ndarray, output: jnp.ndarray, doutput: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray, doutput: jnp.ndarray,
squeezed_mask: jnp.ndarray, attn_bias_type: NVTE_Bias_Type, seqlen: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float, attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool): dropout_probability: float, is_training: bool):
""" """
...@@ -2045,7 +2042,7 @@ def self_fused_attn_bwd(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: jnp.nd ...@@ -2045,7 +2042,7 @@ def self_fused_attn_bwd(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: jnp.nd
rng_state, rng_state,
output, output,
doutput, doutput,
squeezed_mask, seqlen,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
...@@ -2064,13 +2061,13 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive): ...@@ -2064,13 +2061,13 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
outer_primitive = None outer_primitive = None
@staticmethod @staticmethod
def abstract(q_aval, kv_aval, bias_aval, q_mask_or_cu_seqlen_aval, kv_mask_or_cu_seqlen_aval, def abstract(q_aval, kv_aval, bias_aval, q_seqlen_or_cu_seqlen_aval,
seed_aval, *, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, kv_seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, attn_mask_type,
is_training): scaling_factor, dropout_probability, is_training):
""" """
Cross fused attention fwd abstract Cross fused attention fwd abstract
""" """
# outer_primitve is squeezed_mask, inner_primitive is cu_seqlen # outer_primitve is seqlen, inner_primitive is cu_seqlen
del scaling_factor, is_training del scaling_factor, is_training
q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
...@@ -2083,18 +2080,17 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive): ...@@ -2083,18 +2080,17 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
assert q_dtype == kv_dtype == bias_dtype assert q_dtype == kv_dtype == bias_dtype
assert q_batch_shape == kv_batch_shape assert q_batch_shape == kv_batch_shape
assert q_num_head == kv_num_head
assert q_head_dim == kv_head_dim assert q_head_dim == kv_head_dim
assert nkv == 2 assert nkv == 2
assert q_mask_or_cu_seqlen_aval.dtype == kv_mask_or_cu_seqlen_aval.dtype assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
output_shape = q_aval.shape output_shape = q_aval.shape
output_dtype = q_dtype output_dtype = q_dtype
backend = FusedAttnHelper(q_dtype, kv_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD, backend = FusedAttnHelper(q_dtype, kv_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
attn_bias_type, attn_mask_type, dropout_probability, attn_bias_type, attn_mask_type, dropout_probability, q_num_head,
q_num_head, kv_num_head, kv_num_head, q_max_seqlen, kv_max_seqlen,
q_max_seqlen, kv_max_seqlen, q_head_dim).get_fused_attn_backend() q_head_dim).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen: if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_aux_shape = (*q_batch_shape, q_num_head, q_max_seqlen, kv_max_seqlen) softmax_aux_shape = (*q_batch_shape, q_num_head, q_max_seqlen, kv_max_seqlen)
...@@ -2128,7 +2124,7 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive): ...@@ -2128,7 +2124,7 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
*batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape *batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape
batch = reduce(operator.mul, batch_shape) batch = reduce(operator.mul, batch_shape)
kv_max_seqlen = kv_aval.shape[-4] kv_max_seqlen, kv_num_head = kv_aval.shape[-4], kv_aval.shape[-2]
operands = [q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed] operands = [q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed]
operand_shapes = map(lambda x: x.type.shape, operands) operand_shapes = map(lambda x: x.type.shape, operands)
...@@ -2139,7 +2135,7 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive): ...@@ -2139,7 +2135,7 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_fused_attn_descriptor( opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, batch, num_head, kv_num_head, q_max_seqlen, kv_max_seqlen, head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training) jax_dtype_to_te_dtype(q_aval.dtype), is_training)
...@@ -2148,12 +2144,12 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive): ...@@ -2148,12 +2144,12 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
return out return out
@staticmethod @staticmethod
def impl(q, kv, bias, q_squeezed_mask, kv_squeezed_mask, seed, attn_bias_type, attn_mask_type, def impl(q, kv, bias, q_seqlen, kv_seqlen, seed, attn_bias_type, attn_mask_type, scaling_factor,
scaling_factor, dropout_probability, is_training): dropout_probability, is_training):
assert CrossFusedAttnFwdPrimitive.inner_primitive is not None assert CrossFusedAttnFwdPrimitive.inner_primitive is not None
q_cu_seqlen = generate_cu_seqlen(q_squeezed_mask) q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_squeezed_mask) kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
output, softmax_aux, rng_state = CrossFusedAttnFwdPrimitive.inner_primitive.bind( output, softmax_aux, rng_state = CrossFusedAttnFwdPrimitive.inner_primitive.bind(
q, q,
...@@ -2224,9 +2220,8 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive): ...@@ -2224,9 +2220,8 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
register_primitive(CrossFusedAttnFwdPrimitive) register_primitive(CrossFusedAttnFwdPrimitive)
def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, q_seqlen: jnp.ndarray,
q_squeezed_mask: jnp.ndarray, kv_squeezed_mask: jnp.ndarray, kv_seqlen: jnp.ndarray, seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float, attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool): dropout_probability: float, is_training: bool):
""" """
...@@ -2243,8 +2238,8 @@ def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, ...@@ -2243,8 +2238,8 @@ def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
return CrossFusedAttnFwdPrimitive.outer_primitive.bind(q, return CrossFusedAttnFwdPrimitive.outer_primitive.bind(q,
kv, kv,
bias, bias,
q_squeezed_mask, q_seqlen,
kv_squeezed_mask, kv_seqlen,
seed, seed,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
...@@ -2296,7 +2291,7 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive): ...@@ -2296,7 +2291,7 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
*batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape *batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape
batch = reduce(operator.mul, batch_shape) batch = reduce(operator.mul, batch_shape)
kv_max_seqlen = kv_aval.shape[-4] kv_max_seqlen, kv_num_head = kv_aval.shape[-4], kv_aval.shape[-2]
operands = [q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen] operands = [q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen]
operand_shapes = map(lambda x: x.type.shape, operands) operand_shapes = map(lambda x: x.type.shape, operands)
...@@ -2310,7 +2305,7 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive): ...@@ -2310,7 +2305,7 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
# the dropout elements are encoded in the forward auxiliary tensor # the dropout elements are encoded in the forward auxiliary tensor
# so seed is not needed in backward # so seed is not needed in backward
opaque = transformer_engine_jax.pack_fused_attn_descriptor( opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, batch, num_head, kv_num_head, q_max_seqlen, kv_max_seqlen, head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training) jax_dtype_to_te_dtype(q_aval.dtype), is_training)
...@@ -2319,13 +2314,12 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive): ...@@ -2319,13 +2314,12 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
return out return out
@staticmethod @staticmethod
def impl(q, kv, bias, softmax_aux, rng_state, output, doutput, q_squeezed_mask, def impl(q, kv, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen,
kv_squeezed_mask, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training):
is_training):
assert CrossFusedAttnBwdPrimitive.inner_primitive is not None assert CrossFusedAttnBwdPrimitive.inner_primitive is not None
q_cu_seqlen = generate_cu_seqlen(q_squeezed_mask) q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_squeezed_mask) kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
dq, dkv, dbias = CrossFusedAttnBwdPrimitive.inner_primitive.bind( dq, dkv, dbias = CrossFusedAttnBwdPrimitive.inner_primitive.bind(
q, q,
...@@ -2417,10 +2411,9 @@ register_primitive(CrossFusedAttnBwdPrimitive) ...@@ -2417,10 +2411,9 @@ register_primitive(CrossFusedAttnBwdPrimitive)
def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray, softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray,
doutput: jnp.ndarray, q_squeezed_mask: jnp.ndarray, doutput: jnp.ndarray, q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray,
kv_squeezed_mask: jnp.ndarray, attn_bias_type: NVTE_Bias_Type, attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float, scaling_factor: float, dropout_probability: float, is_training: bool):
dropout_probability: float, is_training: bool):
""" """
Wrapper for TE cross fused attention bwd Wrapper for TE cross fused attention bwd
Return the gradients of cross fused attention with packed kv input Return the gradients of cross fused attention with packed kv input
...@@ -2435,8 +2428,8 @@ def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, ...@@ -2435,8 +2428,8 @@ def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
rng_state, rng_state,
output, output,
doutput, doutput,
q_squeezed_mask, q_seqlen,
kv_squeezed_mask, kv_seqlen,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
......
...@@ -82,12 +82,12 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch, ...@@ -82,12 +82,12 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch,
} }
pybind11::bytes PackCustomCallFusedAttnDescriptor( pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim, size_t batch, size_t num_head, size_t num_gqa_groups, size_t q_max_seqlen, size_t kv_max_seqlen,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, DType dtype, bool is_training) { NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
return PackOpaque(CustomCallFusedAttnDescriptor{batch, num_head, q_max_seqlen, kv_max_seqlen, return PackOpaque(CustomCallFusedAttnDescriptor{
head_dim, scaling_factor, dropout_probability, batch, num_head, num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim, scaling_factor,
bias_type, mask_type, dtype, is_training}); dropout_probability, bias_type, mask_type, dtype, is_training});
} }
void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream, void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream,
...@@ -745,8 +745,8 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, ...@@ -745,8 +745,8 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
size_t head_dim) { size_t head_dim) {
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_num_heads, kv_num_heads, mask_type, dropout_probability, q_num_heads, kv_num_heads, q_max_seqlen, kv_max_seqlen,
q_max_seqlen, kv_max_seqlen, head_dim); head_dim);
return backend; return backend;
} }
...@@ -768,6 +768,7 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu ...@@ -768,6 +768,7 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu
auto batch = descriptor.batch; auto batch = descriptor.batch;
auto num_head = descriptor.num_head; auto num_head = descriptor.num_head;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto q_max_seqlen = descriptor.q_max_seqlen; auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen; auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim; auto head_dim = descriptor.head_dim;
...@@ -779,6 +780,9 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu ...@@ -779,6 +780,9 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu
NVTE_CHECK(q_max_seqlen == kv_max_seqlen, NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
"q_max_seqlen should be equal to kv_max_seqlen in the self attention."); "q_max_seqlen should be equal to kv_max_seqlen in the self attention.");
NVTE_CHECK(num_head == num_gqa_groups,
"num_head should be equal to num_gqa_groups in the qkvpacked attention");
auto dtype = descriptor.dtype; auto dtype = descriptor.dtype;
auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim}; auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen}; auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
...@@ -799,10 +803,10 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu ...@@ -799,10 +803,10 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu
// aux tensors // aux tensors
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64); auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend( auto backend =
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type, nvte_get_fused_attn_backend(static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype),
mask_type, dropout_probability, num_head, num_head, qkv_layout, bias_type, mask_type, dropout_probability, num_head,
q_max_seqlen, kv_max_seqlen, head_dim); num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
NVTETensorPack aux_output_tensors; NVTETensorPack aux_output_tensors;
...@@ -853,6 +857,7 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq ...@@ -853,6 +857,7 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq
auto batch = descriptor.batch; auto batch = descriptor.batch;
auto num_head = descriptor.num_head; auto num_head = descriptor.num_head;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto q_max_seqlen = descriptor.q_max_seqlen; auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen; auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim; auto head_dim = descriptor.head_dim;
...@@ -864,6 +869,9 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq ...@@ -864,6 +869,9 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq
NVTE_CHECK(q_max_seqlen == kv_max_seqlen, NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
"q_max_seqlen should be equal to kv_max_seqlen in the self attention."); "q_max_seqlen should be equal to kv_max_seqlen in the self attention.");
NVTE_CHECK(num_head == num_gqa_groups,
"num_head should be equal to num_gqa_groups in the qkvpacked attention");
auto dtype = descriptor.dtype; auto dtype = descriptor.dtype;
auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim}; auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim};
auto output_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}; auto output_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
...@@ -941,6 +949,7 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq ...@@ -941,6 +949,7 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
auto batch = descriptor.batch; auto batch = descriptor.batch;
auto num_head = descriptor.num_head; auto num_head = descriptor.num_head;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto q_max_seqlen = descriptor.q_max_seqlen; auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen; auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim; auto head_dim = descriptor.head_dim;
...@@ -951,7 +960,7 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq ...@@ -951,7 +960,7 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
auto dtype = descriptor.dtype; auto dtype = descriptor.dtype;
auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}; auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_head, head_dim}; auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen}; auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
// input tensors // input tensors
...@@ -976,10 +985,10 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq ...@@ -976,10 +985,10 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64); auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend( auto backend =
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type, nvte_get_fused_attn_backend(static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype),
mask_type, dropout_probability, num_head, num_head, qkv_layout, bias_type, mask_type, dropout_probability, num_head,
q_max_seqlen, kv_max_seqlen, head_dim); num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
NVTETensorPack aux_output_tensors; NVTETensorPack aux_output_tensors;
...@@ -1035,6 +1044,7 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa ...@@ -1035,6 +1044,7 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
auto batch = descriptor.batch; auto batch = descriptor.batch;
auto num_head = descriptor.num_head; auto num_head = descriptor.num_head;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto q_max_seqlen = descriptor.q_max_seqlen; auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen; auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim; auto head_dim = descriptor.head_dim;
...@@ -1045,7 +1055,7 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa ...@@ -1045,7 +1055,7 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
auto dtype = descriptor.dtype; auto dtype = descriptor.dtype;
auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}; auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_head, head_dim}; auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto output_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}; auto output_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen}; auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
......
...@@ -98,6 +98,7 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch, ...@@ -98,6 +98,7 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch,
struct CustomCallFusedAttnDescriptor { struct CustomCallFusedAttnDescriptor {
size_t batch; size_t batch;
size_t num_head; size_t num_head;
size_t num_gqa_groups;
size_t q_max_seqlen; size_t q_max_seqlen;
size_t kv_max_seqlen; size_t kv_max_seqlen;
size_t head_dim; size_t head_dim;
...@@ -110,8 +111,8 @@ struct CustomCallFusedAttnDescriptor { ...@@ -110,8 +111,8 @@ struct CustomCallFusedAttnDescriptor {
}; };
pybind11::bytes PackCustomCallFusedAttnDescriptor( pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim, size_t batch, size_t num_head, size_t num_gqa_groups, size_t q_max_seqlen, size_t kv_max_seqlen,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, DType dtype, bool is_training); NVTE_Mask_Type mask_type, DType dtype, bool is_training);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
......
...@@ -16,7 +16,6 @@ import jax.numpy as jnp ...@@ -16,7 +16,6 @@ import jax.numpy as jnp
import numpy as np import numpy as np
from flax import linen as nn from flax import linen as nn
from flax.linen import partitioning as nn_partitioning from flax.linen import partitioning as nn_partitioning
from jax import dtypes
from jax import nn as jax_nn from jax import nn as jax_nn
from jax import random as jax_random from jax import random as jax_random
from jax import lax, vmap from jax import lax, vmap
...@@ -198,22 +197,31 @@ def core_attention(query: Array, ...@@ -198,22 +197,31 @@ def core_attention(query: Array,
batch_dim = 1 if transpose_batch_sequence else 0 batch_dim = 1 if transpose_batch_sequence else 0
assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], ( assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], (
'q, k, v batch dims must match.') 'q, k, v batch dims must match.')
assert query.shape[-2] == key.shape[-2] == value.shape[-2], ('q, k, v num_heads must match.')
sequence_dim = 0 if transpose_batch_sequence else 1 sequence_dim = 0 if transpose_batch_sequence else 1
assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.' assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.'
assert query.shape[-1] == key.shape[-1], 'q, k depths must match.' assert key.shape[-2] == value.shape[-2], 'k, v num_heads must match.'
assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.'
if float32_logits: if float32_logits:
query = query.astype(jnp.float32) query = query.astype(jnp.float32)
key = key.astype(jnp.float32) key = key.astype(jnp.float32)
h_q, h_kv = query.shape[-2], key.shape[-2]
assert (h_q % h_kv == 0) and (h_q >= h_kv)
group_size = h_q // h_kv
grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))
if transpose_batch_sequence: if transpose_batch_sequence:
attn_weights = jnp.einsum('qbhd,kbhd->bhqk', query, key) attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key)
else: else:
attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key) attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)
attn_weights = checkpoint_name(attn_weights, 'logits') attn_weights = checkpoint_name(attn_weights, 'logits')
b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
attn_weights_without_groups_shape = (b, h * g, q, k)
attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)
attn_weights = _with_sharding_constraint(attn_weights, attn_weights = _with_sharding_constraint(attn_weights,
(BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES)) (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES))
...@@ -229,6 +237,8 @@ def core_attention(query: Array, ...@@ -229,6 +237,8 @@ def core_attention(query: Array,
attn_weights = Softmax(softmax_type=softmax_type, attn_weights = Softmax(softmax_type=softmax_type,
scale_factor=fused_scale_factor)(attn_weights, mask, bias).astype(dtype) scale_factor=fused_scale_factor)(attn_weights, mask, bias).astype(dtype)
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
if not deterministic and dropout_rate > 0.: if not deterministic and dropout_rate > 0.:
keep_prob = 1.0 - dropout_rate keep_prob = 1.0 - dropout_rate
dropout_shape = list(attn_weights.shape) dropout_shape = list(attn_weights.shape)
...@@ -238,9 +248,9 @@ def core_attention(query: Array, ...@@ -238,9 +248,9 @@ def core_attention(query: Array,
attn_weights = attn_weights * multiplier attn_weights = attn_weights * multiplier
if transpose_batch_sequence: if transpose_batch_sequence:
return jnp.einsum('bhqk,kbhd->qbhd', attn_weights, value) return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape)
return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value) return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape)
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
...@@ -262,6 +272,14 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -262,6 +272,14 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
The hidden dimension of each attention head. The hidden dimension of each attention head.
num_heads : int num_heads : int
The number of attention heads The number of attention heads
num_gqa_groups : int, default = `None`
Number of GQA groups. When `None`, it defaults to num_heads.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the queries.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. `num_gqa_groups = num_heads`.
dropout_rate : float, default = 0.0 dropout_rate : float, default = 0.0
Dropout probability for the dropout op during multi-head attention. Dropout probability for the dropout op during multi-head attention.
dropout_rng_name: str, default = 'dropout' dropout_rng_name: str, default = 'dropout'
...@@ -321,6 +339,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -321,6 +339,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
head_dim: int head_dim: int
num_heads: int num_heads: int
num_gqa_groups: int | None = None
dropout_rate: float = 0. dropout_rate: float = 0.
dropout_rng_name: str = 'dropout' dropout_rng_name: str = 'dropout'
layernorm_type: str = "layernorm" layernorm_type: str = "layernorm"
...@@ -342,6 +361,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -342,6 +361,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal') self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_heads
super().__post_init__() super().__post_init__()
@nn.compact @nn.compact
...@@ -428,30 +449,22 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -428,30 +449,22 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
"supported attn_mask_type = {'causal', 'padding'}") "supported attn_mask_type = {'causal', 'padding'}")
is_self_attn = (inputs_q is inputs_kv) is_self_attn = (inputs_q is inputs_kv)
is_gqa = (self.num_heads != self.num_gqa_groups)
is_qkvpack = (is_self_attn and not is_gqa)
qkv_layout = QKVLayout.BS3HD if is_self_attn else QKVLayout.BSHD_BS2HD qkv_layout = QKVLayout.BS3HD if is_self_attn else QKVLayout.BSHD_BS2HD
attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type) attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
canonicalize_dtype = dtypes.canonicalize_dtype(self.dtype)
q_seqlen = inputs_q.shape[0] if self.transpose_batch_sequence else inputs_q.shape[1] q_seqlen = inputs_q.shape[0] if self.transpose_batch_sequence else inputs_q.shape[1]
kv_seqlen = inputs_kv.shape[0] if self.transpose_batch_sequence else inputs_kv.shape[1] kv_seqlen = inputs_kv.shape[0] if self.transpose_batch_sequence else inputs_kv.shape[1]
enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0")) enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))
def _check_seqlen(seqlen):
return seqlen % 64 == 0
def _check_head_dim(head_dim):
return head_dim in [64, 128]
has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype, qkv_layout, has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype, qkv_layout,
attn_bias_type, attn_mask_type, attn_bias_type, attn_mask_type,
self.dropout_rate, self.dropout_rate, self.num_heads,
self.num_heads, self.num_heads, self.num_gqa_groups, q_seqlen,
q_seqlen, kv_seqlen, self.head_dim) kv_seqlen, self.head_dim)
use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \ use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \
canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \
_check_seqlen(q_seqlen) and _check_seqlen(kv_seqlen) and \
_check_head_dim(self.head_dim) and \
has_fused_attn_kernel and \ has_fused_attn_kernel and \
enable_fused_attn enable_fused_attn
...@@ -464,17 +477,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -464,17 +477,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
f"but got {self.transpose_batch_sequence}, " f"but got {self.transpose_batch_sequence}, "
if not self.fuse_qkv: if not self.fuse_qkv:
reason += f"fuse_qkv=True is required but got {self.fuse_qkv}, " reason += f"fuse_qkv=True is required but got {self.fuse_qkv}, "
if canonicalize_dtype not in [jnp.bfloat16, jnp.float16]:
reason += f"dtype in [BF16, FP16] is required " \
f"but got dtype={canonicalize_dtype}, "
if not _check_seqlen(q_seqlen):
reason += f"q_seqlen % 64 == 0 is required " \
f"but got {q_seqlen=}, "
if not _check_seqlen(kv_seqlen):
reason += f"kv_seqlen % 64 == 0 is required " \
f"but got {kv_seqlen=}, "
if not _check_head_dim(self.head_dim):
reason += f"head_dim should be 64 or 128 but got {self.head_dim}, "
if not has_fused_attn_kernel: if not has_fused_attn_kernel:
reason += "no fused attention kernel is available, " reason += "no fused attention kernel is available, "
...@@ -484,7 +486,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -484,7 +486,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
residual = inputs_q residual = inputs_q
if self.fuse_qkv: if self.fuse_qkv:
if is_self_attn: if is_qkvpack:
qkv_proj, ln_out = LayerNormDenseGeneral( qkv_proj, ln_out = LayerNormDenseGeneral(
enable_layernorm=not self.output_layernorm, enable_layernorm=not self.output_layernorm,
layernorm_type=self.layernorm_type, layernorm_type=self.layernorm_type,
...@@ -515,7 +517,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -515,7 +517,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
axis=-1, axis=-1,
features=self.num_heads * self.head_dim, features=self.num_heads * self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm, return_layernorm_output=(self.apply_residual_connection_post_layernorm
or is_self_attn),
scale_axes=(W_NO_SHARD_AXES,), scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes=(W_FSDP_AXES, W_TP_AXES), kernel_axes=(W_FSDP_AXES, W_TP_AXES),
...@@ -525,8 +528,13 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -525,8 +528,13 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
dtype=self.dtype, dtype=self.dtype,
kernel_init=query_init, kernel_init=query_init,
name='query')(inputs_q) name='query')(inputs_q)
if is_self_attn:
assert ln_out is not None
inputs_kv = ln_out
kv_proj = DenseGeneral(axis=-1, kv_proj = DenseGeneral(axis=-1,
features=(2, self.num_heads * self.head_dim), features=(2, self.num_gqa_groups * self.head_dim),
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_init=kv_init, kernel_init=kv_init,
...@@ -542,7 +550,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -542,7 +550,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
kv_projection = functools.partial( kv_projection = functools.partial(
DenseGeneral, DenseGeneral,
axis=-1, axis=-1,
features=self.num_heads * self.head_dim, features=self.num_gqa_groups * self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=(W_FSDP_AXES, W_TP_AXES), kernel_axes=(W_FSDP_AXES, W_TP_AXES),
use_bias=self.use_bias, use_bias=self.use_bias,
...@@ -583,9 +591,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -583,9 +591,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
query = checkpoint_name(query, 'query_proj') query = checkpoint_name(query, 'query_proj')
key = checkpoint_name(key, 'key_proj') key = checkpoint_name(key, 'key_proj')
value = checkpoint_name(value, 'value_proj') value = checkpoint_name(value, 'value_proj')
query = query.reshape((query.shape[0], query.shape[1], self.num_heads, self.head_dim)) query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
key = key.reshape((key.shape[0], key.shape[1], self.num_heads, self.head_dim)) key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
value = value.reshape((value.shape[0], value.shape[1], self.num_heads, self.head_dim)) value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
qkv_sharding_constraint = \ qkv_sharding_constraint = \
(SEQLEN_AXES, BATCH_AXES, HEAD_AXES, HIDDEN_AXES) \ (SEQLEN_AXES, BATCH_AXES, HEAD_AXES, HIDDEN_AXES) \
if self.transpose_batch_sequence \ if self.transpose_batch_sequence \
...@@ -650,7 +658,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -650,7 +658,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
# ensure the old key never used # ensure the old key never used
del dropout_rng del dropout_rng
if is_self_attn: if is_qkvpack:
qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim)) qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim))
qkv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES, qkv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
HIDDEN_AXES) HIDDEN_AXES)
...@@ -667,7 +675,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -667,7 +675,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
else: else:
assert bias is None assert bias is None
query = query.reshape((*query.shape[:-1], self.num_heads, self.head_dim)) query = query.reshape((*query.shape[:-1], self.num_heads, self.head_dim))
kv_proj = kv_proj.reshape((*kv_proj.shape[:-1], self.num_heads, self.head_dim)) kv_proj = kv_proj.reshape((*kv_proj.shape[:-1], self.num_gqa_groups, self.head_dim))
q_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES) q_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES)
kv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES, kv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
HIDDEN_AXES) HIDDEN_AXES)
...@@ -865,6 +873,14 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -865,6 +873,14 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
Intermediate size to which input samples are projected. Intermediate size to which input samples are projected.
num_attention_heads: int, default = 8 num_attention_heads: int, default = 8
Number of attention heads in the transformer layer. Number of attention heads in the transformer layer.
num_gqa_groups : int, default = `None`
Number of GQA groups. When `None`, it defaults to num_attention_heads.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the queries.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm' layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
Indicate the type of layer normalization. Indicate the type of layer normalization.
layernorm_epsilon: float, default = 1e-6 layernorm_epsilon: float, default = 1e-6
...@@ -961,6 +977,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -961,6 +977,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
hidden_size: int = 512 hidden_size: int = 512
mlp_hidden_size: int = 2048 mlp_hidden_size: int = 2048
num_attention_heads: int = 8 num_attention_heads: int = 8
num_gqa_groups: int | None = None
layernorm_type: str = 'layernorm' layernorm_type: str = 'layernorm'
layernorm_epsilon: float = 1e-6 layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False zero_centered_gamma: bool = False
...@@ -995,6 +1012,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -995,6 +1012,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
if self.mlp_kernel_init is None: if self.mlp_kernel_init is None:
self.mlp_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', self.mlp_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in',
'truncated_normal') 'truncated_normal')
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__() super().__post_init__()
@nn.compact @nn.compact
...@@ -1091,6 +1110,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1091,6 +1110,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
num_heads=self.num_attention_heads, num_heads=self.num_attention_heads,
dtype=self.dtype, dtype=self.dtype,
head_dim=head_dim, head_dim=head_dim,
num_gqa_groups=self.num_gqa_groups,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
dropout_rate=self.attention_dropout, dropout_rate=self.attention_dropout,
dropout_rng_name=self.dropout_rng_name, dropout_rng_name=self.dropout_rng_name,
...@@ -1141,6 +1161,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1141,6 +1161,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
num_heads=self.num_attention_heads, num_heads=self.num_attention_heads,
dtype=self.dtype, dtype=self.dtype,
head_dim=head_dim, head_dim=head_dim,
num_gqa_groups=self.num_gqa_groups,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
dropout_rate=self.attention_dropout, dropout_rate=self.attention_dropout,
dropout_rng_name=self.dropout_rng_name, dropout_rng_name=self.dropout_rng_name,
......
...@@ -40,8 +40,8 @@ class QKVLayout(Enum): ...@@ -40,8 +40,8 @@ class QKVLayout(Enum):
def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, attn_mask_type, def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, attn_mask_type,
dropout_probability, num_heads_q, num_heads_kv, dropout_probability, num_heads_q, num_heads_kv, max_seqlen_q,
max_seqlen_q, max_seqlen_kv, head_dim): max_seqlen_kv, head_dim):
""" """
To check whether the fused attention kernel is available To check whether the fused attention kernel is available
""" """
...@@ -83,10 +83,11 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.nda ...@@ -83,10 +83,11 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.nda
seed: jnp.ndarray, attn_bias_type: AttnBiasType, seed: jnp.ndarray, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float, attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool): dropout_probability: float, is_training: bool):
squeezed_mask = mask[..., 0] mask = jnp.logical_not(mask)
actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
output, softmax_aux, rng_state = self_fused_attn_fwd(qkv, output, softmax_aux, rng_state = self_fused_attn_fwd(qkv,
bias, bias,
squeezed_mask, actual_seqlen,
seed, seed,
attn_bias_type=attn_bias_type.value, attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value, attn_mask_type=attn_mask_type.value,
...@@ -96,12 +97,12 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.nda ...@@ -96,12 +97,12 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.nda
output = checkpoint_name(output, 'context') output = checkpoint_name(output, 'context')
softmax_aux = checkpoint_name(softmax_aux, 'context') softmax_aux = checkpoint_name(softmax_aux, 'context')
rng_state = checkpoint_name(rng_state, 'context') rng_state = checkpoint_name(rng_state, 'context')
return output, (qkv, bias, softmax_aux, rng_state, output, squeezed_mask) return output, (qkv, bias, softmax_aux, rng_state, output, actual_seqlen)
def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, dz): is_training, ctx, dz):
qkv, bias, softmax_aux, rng_state, output, squeezed_mask = ctx qkv, bias, softmax_aux, rng_state, output, actual_seqlen = ctx
grad_qkv, grad_bias = self_fused_attn_bwd(qkv, grad_qkv, grad_bias = self_fused_attn_bwd(qkv,
bias, bias,
...@@ -109,7 +110,7 @@ def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dr ...@@ -109,7 +110,7 @@ def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dr
rng_state, rng_state,
output, output,
dz, dz,
squeezed_mask, actual_seqlen,
attn_bias_type=attn_bias_type.value, attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value, attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
...@@ -159,14 +160,19 @@ def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: ...@@ -159,14 +160,19 @@ def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask:
def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type, def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training): scaling_factor, dropout_probability, is_training):
q_squeezed_mask = mask[..., 0] mask = jnp.logical_not(mask)
kv_squeezed_mask = mask[..., 0, :] q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
else:
# When mask is padding + causal, the actual seqlen is not the last row, use max to find it
kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))
output, softmax_aux, rng_state = cross_fused_attn_fwd(q, output, softmax_aux, rng_state = cross_fused_attn_fwd(q,
kv, kv,
bias, bias,
q_squeezed_mask, q_actual_seqlen,
kv_squeezed_mask, kv_actual_seqlen,
seed, seed,
attn_bias_type=attn_bias_type.value, attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value, attn_mask_type=attn_mask_type.value,
...@@ -174,12 +180,12 @@ def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mas ...@@ -174,12 +180,12 @@ def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mas
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training) is_training=is_training)
return output, (q, kv, bias, softmax_aux, rng_state, output, q_squeezed_mask, kv_squeezed_mask) return output, (q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen)
def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, dz): is_training, ctx, dz):
q, kv, bias, softmax_aux, rng_state, output, q_squeezed_mask, kv_squeezed_mask = ctx q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx
grad_q, grad_kv, grad_bias = cross_fused_attn_bwd(q, grad_q, grad_kv, grad_bias = cross_fused_attn_bwd(q,
kv, kv,
...@@ -188,8 +194,8 @@ def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, d ...@@ -188,8 +194,8 @@ def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, d
rng_state, rng_state,
output, output,
dz, dz,
q_squeezed_mask, q_actual_seqlen,
kv_squeezed_mask, kv_actual_seqlen,
attn_bias_type=attn_bias_type.value, attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value, attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
......
...@@ -64,6 +64,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer): ...@@ -64,6 +64,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
head_dim: int = 64 head_dim: int = 64
num_heads: int = 16 num_heads: int = 16
num_gqa_groups: int | None = None
dropout_rate: float = 0. dropout_rate: float = 0.
dropout_rng_name: str = 'dropout' dropout_rng_name: str = 'dropout'
layernorm_type: str = "layernorm" layernorm_type: str = "layernorm"
...@@ -80,6 +81,11 @@ class MultiHeadAttention(TransformerEngineBaseLayer): ...@@ -80,6 +81,11 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
scaled_query_init: bool = True scaled_query_init: bool = True
float32_logits: bool = False float32_logits: bool = False
def __post_init__(self):
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_heads
super().__post_init__()
def setup(self) -> None: def setup(self) -> None:
"""setup""" """setup"""
super().setup() super().setup()
...@@ -89,6 +95,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer): ...@@ -89,6 +95,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
dtype=self.dtype, dtype=self.dtype,
head_dim=self.head_dim, head_dim=self.head_dim,
num_heads=self.num_heads, num_heads=self.num_heads,
num_gqa_groups=self.num_gqa_groups,
dropout_rate=self.dropout_rate, dropout_rate=self.dropout_rate,
dropout_rng_name=self.dropout_rng_name, dropout_rng_name=self.dropout_rng_name,
layernorm_type=self.layernorm_type, layernorm_type=self.layernorm_type,
...@@ -131,6 +138,7 @@ class TransformerLayer(TransformerEngineBaseLayer): ...@@ -131,6 +138,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
hidden_size: int = 512 hidden_size: int = 512
mlp_hidden_size: int = 2048 mlp_hidden_size: int = 2048
num_attention_heads: int = 8 num_attention_heads: int = 8
num_gqa_groups: int | None = None
layernorm_type: str = 'layernorm' layernorm_type: str = 'layernorm'
layernorm_epsilon: float = 1e-6 layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False zero_centered_gamma: bool = False
...@@ -156,6 +164,11 @@ class TransformerLayer(TransformerEngineBaseLayer): ...@@ -156,6 +164,11 @@ class TransformerLayer(TransformerEngineBaseLayer):
scale_attn_logits: bool = False scale_attn_logits: bool = False
scaled_query_init: bool = True scaled_query_init: bool = True
def __post_init__(self):
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
def setup(self) -> None: def setup(self) -> None:
"""setup""" """setup"""
super().setup() super().setup()
...@@ -186,6 +199,7 @@ class TransformerLayer(TransformerEngineBaseLayer): ...@@ -186,6 +199,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
mlp_hidden_size=self.mlp_hidden_size, mlp_hidden_size=self.mlp_hidden_size,
num_attention_heads=self.num_attention_heads, num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
layernorm_type=self.layernorm_type, layernorm_type=self.layernorm_type,
layernorm_epsilon=self.layernorm_epsilon, layernorm_epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment