Unverified commit 8f6c5248, authored by zlsh80826, committed by GitHub

[JAX][Common] Support GQA (#578)

* Support num_gqa_groups arguments
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add GQA support to the JAX bridge code
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix the KV stride of the arbitrary-seqlen backend
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Completely rewrite the fused attention tests and add GQA coverage
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Support unfused GQA
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Calculate seqlen before the primitive for better performance
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add GQA layer tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Apply code style checks for te_jax
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Apply code style checks for tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add num_gqa_groups doc
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refine the qkv_type
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Correct the variable naming
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Handle Max512 CAUSAL
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add a WAR for the latest JAX image
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent daad219f
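For background: grouped-query attention (GQA) lets the num_heads query heads share a smaller set of num_gqa_groups key/value heads, shrinking the KV projections and the KV cache by a factor of num_heads / num_gqa_groups while keeping full query resolution. A minimal JAX sketch of the idea (illustrative only, not code from this commit):

import jax
import jax.numpy as jnp

def gqa_attention(q, k, v):
    """q: [b, s_q, h_q, d]; k, v: [b, s_kv, h_kv, d], with h_q % h_kv == 0."""
    b, s_q, h_q, d = q.shape
    h_kv = k.shape[-2]
    group_size = h_q // h_kv                         # query heads per KV head
    q = q.reshape(b, s_q, h_kv, group_size, d) / jnp.sqrt(d)
    # Contract over head_dim within each KV-head group: [b, h_kv, g, s_q, s_kv]
    logits = jnp.einsum('bqhgd,bkhd->bhgqk', q, k)
    probs = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum('bhgqk,bkhd->bqhgd', probs, v)
    return out.reshape(b, s_q, h_q, d)

q = jnp.ones((2, 8, 12, 64))                         # 12 query heads
k = v = jnp.ones((2, 8, 6, 64))                      # 6 KV heads (num_gqa_groups)
assert gqa_attention(q, k, v).shape == q.shape       # (2, 8, 12, 64)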
@@ -4,6 +4,9 @@
set -xe
# WAR(rewang) for the "Check failed: reduction_kind.has_value()"
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_xla_runtime_executable=true"
: ${TE_PATH:=/opt/transformerengine}
pytest -Wignore -v $TE_PATH/tests/jax/test_distributed_*
@@ -14,5 +14,7 @@ pytest -Wignore -v $TE_PATH/examples/jax/mnist
# Make the encoder tests run-to-run deterministic for stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
# WAR(rewang) for the "Check failed: reduction_kind.has_value()"
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_xla_runtime_executable=true"
pytest -Wignore -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
pytest -Wignore -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
@@ -3,8 +3,8 @@
# See LICENSE for license information.
"""Tests for fused attention"""
import os
from enum import Enum
from dataclasses import dataclass
from functools import partial
from math import sqrt
import jax
@@ -13,18 +13,15 @@ import numpy as np
import pytest
from flax.linen import combine_masks
from flax.linen import dot_product_attention
from flax.linen import make_attention_mask
from flax.linen import make_causal_mask
from flax.linen.dtypes import promote_dtype
from jax import Array
from jax import value_and_grad, jit
from jax.typing import ArrayLike, DTypeLike
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn
from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
from transformer_engine_jax import get_device_compute_capability # pylint: disable=wrong-import-order
# Type annotations
Array = jnp.ndarray
@pytest.fixture(autouse=True, scope='function')
@@ -32,34 +32,55 @@ def clear_live_arrays():
"""
Clear all live arrays to keep resources clean
"""
# Calling custom calls before JAX initialization may cause a CUDA uninitialized error
_ = jnp.zeros(0)
yield
for arr in jax.live_arrays():
arr.delete()
class Backend(Enum):
def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: ArrayLike,
bias: ArrayLike, mask: ArrayLike, deterministic: bool,
dropout_rate: float, dropout_rng: ArrayLike,
dtype: DTypeLike) -> Array:
"""
Fused attn backend.
For unit tests only; the Transformer layer auto-dispatches to the best backend
Similar to flax.linen.dot_product_attention but with GQA support
"""
Max512 = "0"
Arbitrary = "1"
query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
dtype = query.dtype
depth = query.shape[-1]
query = query / jnp.sqrt(depth).astype(dtype)
@pytest.fixture(name="backend", params=[Backend.Max512, Backend.Arbitrary])
def fixture_backend(request):
"""
Fixture of setting up/tearing down backend
"""
backend = request.param
os.environ["NVTE_FUSED_ATTN_BACKEND"] = backend.value
yield backend
os.environ["NVTE_FUSED_ATTN_BACKEND"] = ""
b, s_q, h_q, d = query.shape
_, _, h_kv, _ = key.shape
assert (h_q % h_kv == 0) and (h_q >= h_kv)
num_groups = h_q // h_kv
grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
# logits with shape (b, h_kv, num_groups, s_q, s_kv)
logits = jnp.einsum('...qhgd,...khd->...hgqk', grouped_query, key)
if bias is not None:
if bias.ndim != logits.ndim:
bias = bias.reshape((1, *logits.shape[1:]))
logits = logits + bias
if mask is not None:
if mask.ndim != logits.ndim:
mask = jnp.expand_dims(mask, axis=-3)
logits = jnp.where(mask, logits, jnp.finfo(dtype).min)
softmax_out = jax.nn.softmax(logits).astype(dtype)
if not deterministic and dropout_rate > 0.:
keep_prob = 1.0 - dropout_rate
keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape)
multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
softmax_out = softmax_out * multiplier
SELF_CASES = [(32, 512, 16, 64), (32, 128, 16, 64), (4, 2048, 12, 64)]
CROSS_CASES = [(32, 128, 512, 16, 64)]
DTYPES = [jnp.bfloat16, jnp.float16]
context = jnp.einsum('...hgqk,...khd->...qhgd', softmax_out, value)
context = jnp.reshape(context, query.shape)
return context
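A quick shape check for the GQA reference above (illustrative; uses the helper defined in this file): with 12 query heads and 6 KV heads the output keeps the query shape:

q = jnp.zeros((2, 16, 12, 64))
k = v = jnp.zeros((2, 16, 6, 64))
out = general_dot_product_attention(q, k, v, bias=None, mask=None,
                                    deterministic=True, dropout_rate=0.,
                                    dropout_rng=None, dtype=jnp.float32)
assert out.shape == (2, 16, 12, 64)   # grouped as 6 KV heads x group size 2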
def is_causal_mask(mask: AttnMaskType):
@@ -69,31 +87,28 @@ def is_causal_mask(mask: AttnMaskType):
return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]
def make_decoder_mask(tokens: Array) -> Array:
def make_decoder_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
"""
Create a padded causal mask
"""
causal_mask = make_causal_mask(tokens)
padding_mask = make_attention_mask(tokens > 0, tokens > 0)
q_idxs = jnp.broadcast_to(jnp.arange(q_tokens.shape[-1], dtype=jnp.int32), q_tokens.shape)
kv_idxs = jnp.broadcast_to(jnp.arange(kv_tokens.shape[-1], dtype=jnp.int32), kv_tokens.shape)
causal_mask = make_attention_mask(q_idxs, kv_idxs, jnp.greater_equal)
padding_mask = make_attention_mask(q_tokens > 0, kv_tokens > 0)
return combine_masks(causal_mask, padding_mask)
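For instance (illustrative): with q_tokens = kv_tokens = [[7, 7, 0]], position (i, j) is kept only when j <= i and neither token is padding:

toks = jnp.array([[7, 7, 0]])
print(make_decoder_mask(toks, toks)[0, 0])   # [batch 0, head 0]
# [[1. 0. 0.]
#  [1. 1. 0.]
#  [0. 0. 0.]]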
def jax_self_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
def jax_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
"""
Self attention with JAX native implementation
JAX native dot product attention implementation
"""
attn_mask_type = kwargs['attn_mask_type']
if is_causal_mask(attn_mask_type):
mask = make_decoder_mask(q_token)
mask = make_decoder_mask(q_token, kv_token)
else:
mask = make_attention_mask(q_token > 0, kv_token > 0)
query, key, value = jnp.split(qkv, [1, 2], axis=-3)
query = jnp.squeeze(query)
key = jnp.squeeze(key)
value = jnp.squeeze(value)
output = dot_product_attention(query,
output = general_dot_product_attention(query,
key,
value,
bias=bias,
@@ -102,485 +117,259 @@ def jax_self_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
dropout_rate=kwargs['dropout_probability'],
dropout_rng=dropout_rng,
dtype=jnp.float32)
return output.astype(qkv.dtype)
return output.astype(query.dtype)
def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
"""
Cross attention with JAX native implementation
TE customcall dot product attention implementation
"""
assert q.dtype == kv.dtype
attn_mask_type = kwargs['attn_mask_type']
if is_causal_mask(attn_mask_type):
raise NotImplementedError
mask = make_attention_mask(q_token > 0, kv_token > 0)
query = q
key, value = jnp.split(kv, [1], axis=-3)
key = jnp.squeeze(key)
value = jnp.squeeze(value)
output = dot_product_attention(query,
key,
value,
bias=None,
mask=mask,
deterministic=not kwargs['is_training'],
dropout_rate=kwargs['dropout_probability'],
dropout_rng=dropout_rng,
dtype=jnp.float32)
return output.astype(q.dtype)
def customcall_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
"""
Self fused attention
"""
attn_mask_type = kwargs['attn_mask_type']
if is_causal_mask(attn_mask_type):
mask = make_decoder_mask(q_token)
mask = make_decoder_mask(q_token, kv_token)
else:
mask = make_attention_mask(q_token > 0, kv_token > 0)
# Invert the mask
mask = (mask == 0)
return self_fused_attn(qkv, bias, mask, dropout_rng, **kwargs)
def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
mask = jnp.logical_not(mask)
qkv_layout = kwargs.pop('qkv_layout')
match qkv_layout:
case QKVLayout.BS3HD:
query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
qkv = jnp.concatenate((query, key, value), axis=-3)
return self_fused_attn(qkv, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
case QKVLayout.BSHD_BS2HD:
key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
kv = jnp.concatenate((key, value), axis=-3)
return cross_fused_attn(query, kv, bias, mask, dropout_rng,
**kwargs).astype(query.dtype)
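The two layouts only differ in how the inputs are packed before the custom call; a shape-level sketch (illustrative):

# BS3HD ('qkvpacked'):     q, k, v [b, s, h, d]    -> qkv [b, s, 3, h, d]
# BSHD_BS2HD ('kvpacked'): k, v    [b, s, h_kv, d] -> kv  [b, s, 2, h_kv, d],
#                          q stays [b, s_q, h_q, d], enabling GQA and cross attention
q = k = v = jnp.zeros((2, 16, 8, 64))
assert jnp.stack((q, k, v), axis=2).shape == (2, 16, 3, 8, 64)
assert jnp.stack((k, v), axis=2).shape == (2, 16, 2, 8, 64)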
@dataclass
class FusedAttnRunner:
"""
Cross fused attention
Fused attention runner
"""
assert q.dtype == kv.dtype
attn_mask_type = kwargs['attn_mask_type']
if is_causal_mask(attn_mask_type):
raise NotImplementedError
mask = make_attention_mask(q_token > 0, kv_token > 0)
# Invert the mask
mask = (mask == 0)
return cross_fused_attn(q, kv, None, mask, dropout_rng, **kwargs)
batch_size: int
max_seqlen_q: int
max_seqlen_kv: int
num_heads_q: int
num_heads_kv: int
head_dim: int
attn_bias_type: AttnBiasType
attn_mask_type: AttnMaskType
dropout_prob: float
dtype: DTypeLike
is_training: bool
qkv_layout: QKVLayout
def _check_configs(self):
if self.qkv_layout == QKVLayout.BS3HD and self.num_heads_q != self.num_heads_kv:
pytest.skip("BS3HD layout requires num_heads_q and num_heads_kv to be equal.")
if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv:
pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.")
if not is_fused_attn_kernel_available(
self.dtype, self.dtype, self.qkv_layout, self.attn_bias_type, self.attn_mask_type,
self.dropout_prob, self.num_heads_q, self.num_heads_kv, self.max_seqlen_q,
self.max_seqlen_kv, self.head_dim):
pytest.skip("Unsupported inputs combination or device compute capability.")
def _setup_inputs(self):
self._check_configs()
key = jax.random.PRNGKey(0)
q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
@pytest.mark.parametrize('b, s, h, d', SELF_CASES)
@pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS])
@pytest.mark.parametrize('attn_mask_type', [
AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK,
AttnMaskType.PADDING_CAUSAL_MASK
])
@pytest.mark.parametrize('dropout_probability', [0., 0.1])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True, False])
class TestSelfFusedAttn():
"""Tests for transformer_engine.jax.fused_attn.self_fused_attn"""
q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
k_shape = v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim)
bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
@staticmethod
def _check_inputs(s, *, attn_bias_type, attn_mask_type, backend, dropout_probability, dtype,
num_heads_q, num_heads_kv, head_dim):
self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1)
self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1)
self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1)
assert isinstance(backend, Backend)
with_bias = self.attn_bias_type != AttnBiasType.NO_BIAS
self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1) if with_bias else None
if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type,
attn_mask_type, dropout_probability,
num_heads_q, num_heads_kv,
s, s, head_dim):
pytest.skip("Unsupported inputs combination or device compute capability.")
def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend,
dropout_probability, dtype, is_training):
"""Setup the test inputs"""
self.__class__._check_inputs(s,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
backend=backend,
dropout_probability=dropout_probability,
dtype=dtype,
num_heads_q=h,
num_heads_kv=h,
head_dim=d)
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
if self.attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
pad_ratio = 0.0
else:
pad_ratio = 0.3
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
def gen_valid(bs, max_seqlen, pad_ratio):
pad_len = int(max_seqlen * pad_ratio)
valid_len = max_seqlen - pad_len
tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
return valid_len, tokens
qkv_shape = (b, s, 3, h, d)
bias_shape = (1, h, s, s)
self.valid_len_q, self.token_q = gen_valid(self.batch_size, self.max_seqlen_q, pad_ratio)
self.valid_len_kv, self.token_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio)
pad_len = int(s * pad_ratio)
self.valid_len = s - pad_len
self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
self.scaling_factor = 1. / sqrt(self.head_dim)
min_val, max_val = -1, 1
self.qkv = jax.random.uniform(subkeys[0], qkv_shape, dtype, min_val, max_val)
def test_forward(self):
"""
Test forward without JIT
"""
self._setup_inputs()
with_bias = attn_bias_type != AttnBiasType.NO_BIAS
self.bias = jax.random.uniform(subkeys[1], bias_shape, dtype, min_val,
max_val) if with_bias else None
args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
kwargs = {
'attn_bias_type': self.attn_bias_type,
'attn_mask_type': self.attn_mask_type,
'scaling_factor': self.scaling_factor,
'dropout_probability': self.dropout_prob,
'is_training': self.is_training,
'qkv_layout': self.qkv_layout,
}
self.q_token = jnp.concatenate((jnp.ones((b, self.valid_len)), jnp.zeros((b, pad_len))),
axis=-1)
self.kv_token = self.q_token
# Convert the outputs to float32 for the elementwise comparison
primitive_out = customcall_fused_dpa(*args, **kwargs).astype(jnp.float32)
reference_out = jax_dpa(*args, **kwargs).astype(jnp.float32)
self.scaling_factor = 1. / sqrt(d)
self.dropout_probability = dropout_probability
self.dropout_rng = jax.random.PRNGKey(0) if self.dropout_probability > 0 else None
self.attn_bias_type = attn_bias_type
self.attn_mask_type = attn_mask_type
self.is_training = is_training
primitive_valid, primitive_invalid = jnp.split(primitive_out, (self.valid_len_q,), axis=1)
reference_valid, _ = jnp.split(reference_out, (self.valid_len_q,), axis=1)
def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend, dropout_probability,
dtype, is_training):
"""
Test forward without using JIT
"""
self._set_inputs(b,
s,
h,
d,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
backend=backend,
dropout_probability=dropout_probability,
dtype=dtype,
is_training=is_training)
primitive_out = customcall_self_fused_attn(self.qkv,
self.bias,
self.q_token,
self.kv_token,
self.dropout_rng,
attn_bias_type=self.attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=self.scaling_factor,
dropout_probability=self.dropout_probability,
is_training=self.is_training)
reference_out = jax_self_attn(self.qkv,
self.bias,
self.q_token,
self.kv_token,
self.dropout_rng,
attn_mask_type=attn_mask_type,
scaling_factor=self.scaling_factor,
dropout_probability=self.dropout_probability,
is_training=self.is_training)
ref_valid, _ = jnp.split(reference_out, (self.valid_len,), axis=1)
pri_valid, pri_invalid = jnp.split(primitive_out, (self.valid_len,), axis=1)
# Dropout results are not bit-exact, so skip the elementwise comparison
if is_training and dropout_probability > 0.:
# Skip elementwise comparison when dropout enabled
if self.is_training and self.dropout_prob > 0.:
return
np.testing.assert_allclose(jnp.asarray(pri_valid, np.float32),
jnp.asarray(ref_valid, np.float32),
rtol=1e-4,
atol=1e-2)
np.testing.assert_allclose(jnp.asarray(pri_invalid, jnp.float32),
jnp.zeros_like(pri_invalid, jnp.float32))
np.testing.assert_allclose(primitive_valid, reference_valid, atol=1e-2, rtol=1e-4)
np.testing.assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid))
def test_forward_backward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend,
dropout_probability, dtype, is_training):
def test_backward(self):
"""
Test forward, backward, and autodiff by jax.value_and_grad
Test value_and_grad with JIT, which includes both forward and backward
"""
if not is_training:
pytest.skip(f"Backward doesn't support {is_training=}")
self._set_inputs(b,
s,
h,
d,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
backend=backend,
dropout_probability=dropout_probability,
dtype=dtype,
is_training=is_training)
def grad_func(fused_attn_func, *args, **kwargs):
if not self.is_training:
pytest.skip("Backward doesn't support inference")
self._setup_inputs()
def grad_func(func, *args, **kwargs):
# Gradient is small, use a gradient multiplier to amplify the gradient
gradient_multiplier = 1000 if dtype == jnp.bfloat16 else 10000
if is_causal_mask(attn_mask_type):
gradient_multiplier = gradient_multiplier / 10
gradient_multiplier = self.valid_len_q * self.num_heads_q
if is_causal_mask(self.attn_mask_type):
gradient_multiplier /= 10
# Keep only valid result for the gradient
# fused_attn output has shape (b, s, h, d)
valid_fused_attn_ret, _ = jnp.split(fused_attn_func(*args, **kwargs), (self.valid_len,),
axis=1)
return (jnp.mean(valid_fused_attn_ret, dtype=jnp.float32) *
gradient_multiplier).astype(dtype)
ret_valid, _ = jnp.split(func(*args, **kwargs), (self.valid_len_q,), axis=1)
return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(self.dtype)
args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
kwargs = {
'attn_bias_type': self.attn_bias_type,
'attn_mask_type': attn_mask_type,
'attn_mask_type': self.attn_mask_type,
'scaling_factor': self.scaling_factor,
'dropout_probability': self.dropout_probability,
'is_training': self.is_training
'dropout_probability': self.dropout_prob,
'is_training': self.is_training,
'qkv_layout': self.qkv_layout,
}
# Summing the results in FP16/BF16 may overflow, so use FP32 for the summation
jitted_primitive = jit(
value_and_grad(
lambda qkv, bias, q_token, kv_token, dropout_rng: grad_func(
customcall_self_fused_attn, qkv, bias, q_token, kv_token, dropout_rng, **kwargs
), (0, 1)))
lambda q, k, v, bias, *args: grad_func(customcall_fused_dpa, q, k, v, bias, *args,
**kwargs), (0, 1, 2, 3)))
jitted_reference = jit(
value_and_grad(
lambda qkv, bias, q_token, kv_token, dropout_rng: grad_func(
jax_self_attn, qkv, bias, q_token, kv_token, dropout_rng, **kwargs), (0, 1)))
primitive_out, (primitive_dqkv,
primitive_dbias) = jitted_primitive(self.qkv, self.bias, self.q_token,
self.kv_token, self.dropout_rng)
lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
(0, 1, 2, 3)))
reference_out, (reference_dqkv,
reference_dbias) = jitted_reference(self.qkv, self.bias, self.q_token,
self.kv_token, self.dropout_rng)
primitive_out, primitive_dgrad = jitted_primitive(*args)
reference_out, reference_dgrad = jitted_reference(*args)
# Dropout results are not bit-exact, so skip the elementwise comparison
if dropout_probability > 0.:
# Skip elementwise comparison when dropout enabled
if self.dropout_prob > 0.:
return
np.testing.assert_allclose(jnp.asarray(primitive_out, np.float32),
jnp.asarray(reference_out, np.float32),
rtol=1e-4,
atol=1e-5)
valid_primitive_dqkv, invalid_primitive_dqkv = \
jnp.split(primitive_dqkv.astype(jnp.float32), (self.valid_len,), axis=1)
valid_reference_dqkv, invalid_reference_dqkv = \
jnp.split(reference_dqkv.astype(jnp.float32), (self.valid_len,), axis=1)
np.testing.assert_allclose(primitive_out.astype(jnp.float32),
reference_out.astype(jnp.float32),
atol=1e-5,
rtol=1e-3)
valid_primitive_dq, valid_primitive_dk, valid_primitive_dv = \
jnp.split(valid_primitive_dqkv, 3, axis=2)
valid_reference_dq, valid_reference_dk, valid_reference_dv = \
jnp.split(valid_reference_dqkv, 3, axis=2)
# Convert the outputs to float32 for the elementwise comparison
primitive_dq, primitive_dk, primitive_dv, primitive_dbias = map(
jnp.float32, primitive_dgrad)
reference_dq, reference_dk, reference_dv, reference_dbias = map(
jnp.float32, reference_dgrad)
np.testing.assert_allclose(valid_primitive_dq, valid_reference_dq, rtol=1e-4, atol=1e-5)
np.testing.assert_allclose(valid_primitive_dk, valid_reference_dk, rtol=1e-4, atol=1e-5)
np.testing.assert_allclose(valid_primitive_dv, valid_reference_dv, rtol=1e-4, atol=1e-5)
def check_dqkv(primitive, reference, valid_len):
primitive_valid, primitive_invalid = jnp.split(primitive, (valid_len,), axis=1)
reference_valid, reference_invalid = jnp.split(reference, (valid_len,), axis=1)
assert jnp.allclose(invalid_primitive_dqkv, invalid_reference_dqkv)
np.testing.assert_allclose(primitive_valid, reference_valid, atol=1e-4, rtol=1e-3)
assert jnp.allclose(primitive_invalid, reference_invalid)
assert jnp.allclose(primitive_invalid, jnp.zeros_like(primitive_invalid))
# Padded part should be 0s
assert jnp.allclose(invalid_primitive_dqkv, jnp.zeros_like(invalid_primitive_dqkv))
check_dqkv(primitive_dq, reference_dq, self.valid_len_q)
check_dqkv(primitive_dk, reference_dk, self.valid_len_kv)
check_dqkv(primitive_dv, reference_dv, self.valid_len_kv)
if self.attn_bias_type != AttnBiasType.NO_BIAS:
# dbias valid part
np.testing.assert_allclose(
jnp.asarray(primitive_dbias[:, :, :self.valid_len, :self.valid_len], np.float32),
jnp.asarray(reference_dbias[:, :, :self.valid_len, :self.valid_len], np.float32),
rtol=1e-4,
atol=3e-5)
np.testing.assert_allclose(primitive_dbias[..., :self.valid_len_q, :self.valid_len_kv],
reference_dbias[..., :self.valid_len_q, :self.valid_len_kv],
atol=3e-5,
rtol=1e-4)
# dbias padded part
np.testing.assert_allclose(
jnp.asarray(primitive_dbias[:, :, self.valid_len:, self.valid_len:], np.float32),
jnp.asarray(reference_dbias[:, :, self.valid_len:, self.valid_len:], np.float32))
np.testing.assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
reference_dbias[..., self.valid_len_q:, self.valid_len_kv:])
assert jnp.allclose(
primitive_dbias[:, :, self.valid_len:, self.valid_len:],
jnp.zeros_like(primitive_dbias[:, :, self.valid_len:, self.valid_len:]))
@pytest.mark.skipif(get_device_compute_capability(0) not in [80, 90],
reason="Fused attention kernel is not supported.")
@pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_CASES)
@pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK])
@pytest.mark.parametrize('dropout_probability', [0., 0.1])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True, False])
@pytest.mark.parametrize('pad_ratio', [0.3])
class TestCrossFusedAttn():
"""Tests for transformer_engine.jax.fused_attn.cross_fused_attn"""
def _set_inputs(self, b, s_q, s_kv, h, d, *, attn_mask_type, dropout_probability, dtype,
is_training, pad_ratio):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
q_shape = (b, s_q, h, d)
kv_shape = (b, s_kv, 2, h, d)
q_pad_len = int(s_q * pad_ratio)
kv_pad_len = int(s_kv * pad_ratio)
self.q_valid_len = s_q - q_pad_len
self.kv_valid_len = s_kv - kv_pad_len
min_val, max_val = -1, 1
self.q = jax.random.uniform(subkeys[0], q_shape, dtype, min_val, max_val)
self.kv = jax.random.uniform(subkeys[1], kv_shape, dtype, min_val, max_val)
self.q_token = jnp.concatenate((jnp.ones((b, self.q_valid_len)), jnp.zeros((b, q_pad_len))),
axis=-1)
self.kv_token = jnp.concatenate((jnp.ones((b, self.kv_valid_len)), jnp.zeros(
(b, kv_pad_len))),
axis=-1)
self.scaling_factor = 1. / sqrt(d)
self.dropout_probability = dropout_probability
self.dropout_rng = jax.random.PRNGKey(0) if self.dropout_probability > 0 else None
self.attn_bias_type = AttnBiasType.NO_BIAS
self.attn_mask_type = attn_mask_type
self.is_training = is_training
def test_forward(self, b, s_q, s_kv, h, d, attn_mask_type, dropout_probability, dtype,
is_training, pad_ratio):
"""
Test forward without using JIT
"""
self._set_inputs(b,
s_q,
s_kv,
h,
d,
attn_mask_type=attn_mask_type,
dropout_probability=dropout_probability,
dtype=dtype,
is_training=is_training,
pad_ratio=pad_ratio)
primitive_out = customcall_cross_fused_attn(self.q,
self.kv,
self.q_token,
self.kv_token,
self.dropout_rng,
attn_bias_type=self.attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=self.scaling_factor,
dropout_probability=self.dropout_probability,
is_training=self.is_training)
reference_out = jax_cross_attn(self.q,
self.kv,
self.q_token,
self.kv_token,
self.dropout_rng,
attn_mask_type=attn_mask_type,
scaling_factor=self.scaling_factor,
dropout_probability=self.dropout_probability,
is_training=self.is_training)
# Dropout results are not bit-exact, so skip the elementwise comparison
if is_training and dropout_probability > 0.:
return
ref_valid, _ = jnp.split(reference_out, (self.q_valid_len,), axis=1)
pri_valid, pri_invalid = jnp.split(primitive_out, (self.q_valid_len,), axis=1)
np.testing.assert_allclose(jnp.asarray(pri_valid, np.float32),
jnp.asarray(ref_valid, np.float32),
rtol=1e-4,
atol=2e-3)
primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
jnp.zeros_like(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:]))
np.testing.assert_allclose(jnp.asarray(pri_invalid, jnp.float32),
jnp.zeros_like(pri_invalid, jnp.float32))
def test_forward_backward(self, b, s_q, s_kv, h, d, attn_mask_type, dropout_probability, dtype,
is_training, pad_ratio):
@pytest.mark.parametrize('attn_bias_type', [
pytest.param(AttnBiasType.NO_BIAS, id='NO_BIAS'),
pytest.param(AttnBiasType.POST_SCALE_BIAS, id='POST_SCALE_BIAS'),
])
@pytest.mark.parametrize('attn_mask_type', [
pytest.param(AttnMaskType.NO_MASK, id='NO_MASK'),
pytest.param(AttnMaskType.PADDING_MASK, id='PADDING'),
pytest.param(AttnMaskType.CAUSAL_MASK, id='CAUSAL'),
pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id='PADDING_CAUSAL'),
])
@pytest.mark.parametrize('qkv_layout', [
pytest.param(QKVLayout.BS3HD, id='qkvpacked'),
pytest.param(QKVLayout.BSHD_BS2HD, id='kvpacked'),
])
@pytest.mark.parametrize('dropout_prob', [0., 0.1])
@pytest.mark.parametrize('is_training',
[pytest.param(True, id='training'),
pytest.param(False, id='inference')])
@pytest.mark.parametrize(
'dtype', [pytest.param(jnp.bfloat16, id="BF16"),
pytest.param(jnp.float16, id="FP16")])
@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d',
[(32, 128, 128, 16, 16, 64), (4, 2048, 2048, 12, 12, 64),
pytest.param(32, 512, 128, 16, 16, 64, id='32-512-128-16-16-64-cross'),
pytest.param(4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA')])
class TestFusedAttn:
"""
Test forward, backward, and autodiff by jax.value_and_grad
Fused attention tester
"""
if not is_training:
pytest.skip(f"Backward doesn't support {is_training=}")
self._set_inputs(b,
s_q,
s_kv,
h,
d,
attn_mask_type=attn_mask_type,
dropout_probability=dropout_probability,
dtype=dtype,
is_training=is_training,
pad_ratio=pad_ratio)
def grad_func(fused_attn_func, *args, **kwargs):
# Gradient is small, use a gradient multiplier to amplify the gradient
gradient_multiplier = 1e4
# Keep only valid result for the gradient
# fused_attn output has shape (b, s_q, h, d)
valid_fused_attn_ret, _ = jnp.split(fused_attn_func(*args, **kwargs),
(self.q_valid_len,),
axis=1)
return (jnp.mean(valid_fused_attn_ret, dtype=jnp.float32) *
gradient_multiplier).astype(dtype)
kwargs = {
'attn_bias_type': self.attn_bias_type,
'attn_mask_type': attn_mask_type,
'scaling_factor': self.scaling_factor,
'dropout_probability': self.dropout_probability,
'is_training': self.is_training
}
# Summing the results in FP16/BF16 may overflow, so use FP32 for the summation
jitted_primitive = jit(
value_and_grad(
lambda q, kv, q_token, kv_token, dropout_rng: grad_func(
customcall_cross_fused_attn, q, kv, q_token, kv_token, dropout_rng, **kwargs),
(0, 1)))
jitted_reference = jit(
value_and_grad(
lambda q, kv, q_token, kv_token, dropout_rng: grad_func(
jax_cross_attn, q, kv, q_token, kv_token, dropout_rng, **kwargs), (0, 1)))
primitive_out, (primitive_dq,
primitive_dkv) = jitted_primitive(self.q, self.kv, self.q_token,
self.kv_token, self.dropout_rng)
reference_out, (reference_dq,
reference_dkv) = jitted_reference(self.q, self.kv, self.q_token,
self.kv_token, self.dropout_rng)
# Dropout results are not bit-exact, so skip the elementwise comparison
if dropout_probability > 0.:
return
@staticmethod
def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
dtype, is_training, qkv_layout):
"""
Test forward with parameterized configs
"""
runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
dropout_prob, dtype, is_training, qkv_layout)
runner.test_forward()
np.testing.assert_allclose(jnp.asarray(primitive_out, np.float32),
jnp.asarray(reference_out, np.float32),
rtol=1e-4,
atol=1e-5)
valid_primitive_dq, invalid_primitive_dq = jnp.split(primitive_dq, (self.q_valid_len,),
axis=1)
valid_reference_dq, invalid_reference_dq = jnp.split(reference_dq, (self.q_valid_len,),
axis=1)
valid_primitive_dkv, invalid_primitive_dkv = jnp.split(primitive_dkv, (self.kv_valid_len,),
axis=1)
valid_reference_dkv, invalid_reference_dkv = jnp.split(reference_dkv, (self.kv_valid_len,),
axis=1)
# dQ
np.testing.assert_allclose(jnp.asarray(valid_primitive_dq, np.float32),
jnp.asarray(valid_reference_dq, np.float32),
rtol=1e-4,
atol=1e-5)
# dK
np.testing.assert_allclose(jnp.asarray(valid_primitive_dkv[:, :, 0], np.float32),
jnp.asarray(valid_reference_dkv[:, :, 0], np.float32),
rtol=1e-4,
atol=1e-5)
# dV
np.testing.assert_allclose(jnp.asarray(valid_primitive_dkv[:, :, 1], np.float32),
jnp.asarray(valid_reference_dkv[:, :, 1], np.float32),
rtol=1e-4,
atol=1e-5)
assert jnp.allclose(invalid_primitive_dq, invalid_reference_dq)
assert jnp.allclose(invalid_primitive_dkv, invalid_reference_dkv)
# Padded part should be 0s
assert jnp.allclose(invalid_primitive_dq, jnp.zeros_like(invalid_primitive_dq))
assert jnp.allclose(invalid_primitive_dkv, jnp.zeros_like(invalid_primitive_dkv))
@staticmethod
def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
dtype, is_training, qkv_layout):
"""
Test backward with parameterized configs
"""
runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
dropout_prob, dtype, is_training, qkv_layout)
runner.test_backward()
@@ -9,13 +9,14 @@ import jax
import jax.numpy as jnp
import pytest
from transformer_engine.common.recipe import Format
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
from utils import assert_allclose
from utils import DecoderLayer as RefDecoderLayer
from utils import EncoderLayer as RefEncoderLayer
from transformer_engine.common.recipe import Format
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
is_fp8_supported, reason = is_fp8_available()
@@ -85,8 +86,13 @@ _KEY_OF_LAYERNORM_TYPE = 'layernorm_type'
_KEY_OF_ZERO_CENTERED_GAMMA = 'zero_centered_gamma'
_KEY_OF_TRANSPOSE_BS = 'transpose_batch_sequence'
_KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits"
_KEY_OF_NUM_HEADS = 'num_attention_heads'
_KEY_OF_NUM_GQA_GROUPS = 'num_gqa_groups'
BASE_ATTRS = {_KEY_OF_TRANSPOSE_BS: True}
BASE_ATTRS = {
_KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_NUM_HEADS: 8,
}
ATTRS = [{
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
@@ -129,6 +135,9 @@ ATTRS = [{
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_MLP_ACTIVATIONS: (('gelu', 'linear')),
_KEY_OF_FUSE_MLP_WI: True
}, {
_KEY_OF_NUM_HEADS: 8,
_KEY_OF_NUM_GQA_GROUPS: 4
}]
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
@@ -137,21 +146,13 @@ ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
class TestEncoderLayer:
@staticmethod
def sync_params(ref, target, attrs):
fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)
def sync_params(ref, target):
unfreeze_target = flax.core.unfreeze(target)
if fuse_qkv:
unfreeze_target['attention']['qkv']['kernel'] = \
jnp.reshape(ref['attention']['qkv']['kernel'],
unfreeze_target['attention']['qkv']['kernel'].shape)
else:
unfreeze_target['attention']['query']['kernel'] = \
ref['attention']['query']['kernel']
unfreeze_target['attention']['key']['kernel'] = \
ref['attention']['key']['kernel']
unfreeze_target['attention']['value']['kernel'] = \
ref['attention']['value']['kernel']
unfreeze_attn_scope = unfreeze_target['attention']
ref_attn_scope = ref['attention']
for key in ref_attn_scope.keys():
unfreeze_attn_scope[key]['kernel'] = \
ref_attn_scope[key]['kernel'].reshape(unfreeze_attn_scope[key]['kernel'].shape)
unfreeze_target['mlp']['wi_kernel'] = \
jnp.reshape(ref['mlp']['wi']['kernel'], unfreeze_target['mlp']['wi_kernel'].shape)
unfreeze_target['mlp']['wo_kernel'] = \
@@ -196,7 +197,7 @@ class TestEncoderLayer:
test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
test_masks)
ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params, attrs)
ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params)
ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng)
test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng)
@@ -242,7 +243,7 @@ class TestEncoderLayer:
test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
test_masks)
ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params, attrs)
ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params)
if FP8Helper.is_fp8_enabled():
for _ in range(4):
@@ -266,7 +267,10 @@ class TestEncoderLayer:
assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad
def reorganize_test_wgrad(test_wgrad, attrs):
fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)
num_heads = attrs.get(_KEY_OF_NUM_HEADS)
num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads)
fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True) and \
num_heads == num_gqa_groups
attn_name = 'attention'
unfreeze_test_wgrad = flax.core.unfreeze(test_wgrad)
@@ -280,10 +284,12 @@ class TestEncoderLayer:
unfreeze_test_wgrad['pre_attention_layer_norm']['ln_bias'] = \
unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
if fuse_qkv:
unfreeze_test_wgrad[attn_name]['qkv']['kernel'] = \
jnp.reshape(unfreeze_test_wgrad[attn_name]['qkv']['kernel'],
(unfreeze_test_wgrad[attn_name]['qkv']['kernel'].shape[0], -1))
for key in unfreeze_test_wgrad[attn_name].keys():
unfreeze_test_wgrad[attn_name][key]['kernel'] = \
jnp.reshape(unfreeze_test_wgrad[attn_name][key]['kernel'],
(unfreeze_test_wgrad[attn_name][key]['kernel'].shape[0], -1))
unfreeze_test_wgrad['pre_mlp_layer_norm'] = {}
unfreeze_test_wgrad['pre_mlp_layer_norm']['scale'] = \
unfreeze_test_wgrad['mlp']['scale']
@@ -348,26 +354,14 @@ class TestEncoderLayer:
class TestDecoderLayer:
@staticmethod
def sync_params(ref, target, attrs):
fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)
def sync_params(ref, target):
unfreeze_target = flax.core.unfreeze(target)
if fuse_qkv:
unfreeze_target['self_attention']['qkv']['kernel'] = \
jnp.reshape(ref['self_attention']['qkv']['kernel'],
unfreeze_target['self_attention']['qkv']['kernel'].shape)
unfreeze_target['encoder_decoder_attention']['kv']['kernel'] = \
jnp.reshape(ref['encoder_decoder_attention']['kv']['kernel'],
unfreeze_target['encoder_decoder_attention']['kv']['kernel'].shape)
else:
unfreeze_target['self_attention']['query']['kernel'] = \
ref['self_attention']['query']['kernel']
unfreeze_target['self_attention']['key']['kernel'] = \
ref['self_attention']['key']['kernel']
unfreeze_target['self_attention']['value']['kernel'] = \
ref['self_attention']['value']['kernel']
unfreeze_target['encoder_decoder_attention']['query']['kernel'] = \
ref['encoder_decoder_attention']['query']['kernel']
for scope in ['self_attention', 'encoder_decoder_attention']:
unfreeze_scope = unfreeze_target[scope]
ref_scope = ref[scope]
for key in unfreeze_scope.keys():
unfreeze_scope[key]['kernel'] = \
ref_scope[key]['kernel'].reshape(unfreeze_scope[key]['kernel'].shape)
unfreeze_target['mlp']['wi_kernel'] = \
jnp.reshape(ref['mlp']['wi']['kernel'], unfreeze_target['mlp']['wi_kernel'].shape)
unfreeze_target['mlp']['wo_kernel'] = \
@@ -412,7 +406,7 @@ class TestDecoderLayer:
test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
test_masks)
ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params, attrs)
ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params)
ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng)
test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng)
@@ -459,7 +453,7 @@ class TestDecoderLayer:
test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
test_masks)
ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params, attrs)
ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params)
if FP8Helper.is_fp8_enabled():
for _ in range(4):
@@ -483,11 +477,14 @@ class TestDecoderLayer:
assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad
def reorganize_test_wgrad(test_wgrad, attrs):
fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)
attn_name = 'self_attention'
num_heads = attrs.get(_KEY_OF_NUM_HEADS)
num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads)
fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True) and \
num_heads == num_gqa_groups
unfreeze_test_wgrad = flax.core.unfreeze(test_wgrad)
if "output_layernorm" not in attrs:
attn_name = 'self_attention'
unfreeze_test_wgrad['pre_self_attention_layer_norm'] = {}
pre_attn_layer_key = 'qkv' if fuse_qkv else 'query'
unfreeze_test_wgrad['pre_self_attention_layer_norm']['scale'] = \
@@ -498,14 +495,11 @@ class TestDecoderLayer:
unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
if fuse_qkv:
unfreeze_test_wgrad[attn_name]['qkv']['kernel'] = \
jnp.reshape(unfreeze_test_wgrad[attn_name]['qkv']['kernel'],
(unfreeze_test_wgrad[attn_name]['qkv']['kernel'].shape[0], -1))
attn_name = 'encoder_decoder_attention'
unfreeze_test_wgrad[attn_name]['kv']['kernel'] = \
jnp.reshape(unfreeze_test_wgrad[attn_name]['kv']['kernel'],
(unfreeze_test_wgrad[attn_name]['kv']['kernel'].shape[0], -1))
for scope in ['self_attention', 'encoder_decoder_attention']:
for key in unfreeze_test_wgrad[scope].keys():
unfreeze_test_wgrad[scope][key]['kernel'] = \
jnp.reshape(unfreeze_test_wgrad[scope][key]['kernel'],
(unfreeze_test_wgrad[scope][key]['kernel'].shape[0], -1))
unfreeze_test_wgrad['pre_cross_attention_layer_norm'] = {}
unfreeze_test_wgrad['pre_cross_attention_layer_norm']['scale'] = \
@@ -12,6 +12,8 @@ from praxis import pax_fiddle
from praxis.base_layer import WeightInit, DEFAULT_INIT_MUTABLE_LIST
import pytest
from utils import assert_allclose
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast, update_fp8_metas, update_collections
from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral
@@ -23,12 +25,12 @@ from transformer_engine.jax.flax import TransformerLayer as flax_TransformerLaye
from transformer_engine.jax.flax.module import Softmax
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
from transformer_engine.jax.praxis import LayerNorm
from transformer_engine.jax.praxis import FusedSoftmax, LayerNorm
from transformer_engine.jax.praxis import FusedSoftmax
from transformer_engine.jax.praxis import LayerNormLinear, LayerNormMLP, Linear
from transformer_engine.jax.praxis import MultiHeadAttention, RelativePositionBiases
from transformer_engine.jax.praxis import TransformerEngineBaseLayer, TransformerLayer, TransformerLayerType
from transformer_engine.jax.praxis import TransformerEngineBaseLayer
from transformer_engine.jax.praxis import TransformerLayer, TransformerLayerType
from transformer_engine.jax.softmax import SoftmaxType
from utils import assert_allclose
is_fp8_supported, reason = is_fp8_available()
@@ -674,6 +676,8 @@ class MultiHeadAttnAttr:
LN_TYPE = 'layernorm_type'
ATTN_MASK_TYPE = 'attn_mask_type'
ZERO_CEN = 'zero_centered_gamma'
NUM_ATTN_HEADS = 'num_attention_heads'
NUM_GQA_GROUPS = 'num_gqa_groups'
ATTRS = [{
USE_BIAS: True,
LN_TYPE: 'layernorm',
@@ -704,6 +708,13 @@ class MultiHeadAttnAttr:
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ATTN_MASK_TYPE: 'causal'
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
NUM_ATTN_HEADS: 8,
NUM_GQA_GROUPS: 4,
ATTN_MASK_TYPE: 'causal'
}]
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utility for the TE layer tests"""
import functools
import math
@@ -28,6 +29,9 @@ Initializer = Callable[[PRNGKey, Shape, DType], Array]
def is_devices_enough(required):
"""
Check whether the number of available GPUs is sufficient
"""
return len(jax.devices()) >= required
@@ -121,9 +125,9 @@ def dot_product_attention(query: Array,
query: queries for calculating attention with shape of `[batch, q_length,
num_heads, qk_depth_per_head]`.
key: keys for calculating attention with shape of `[batch, kv_length,
num_heads, qk_depth_per_head]`.
num_gqa_groups, qk_depth_per_head]`.
value: values to be used in attention with shape of `[batch, kv_length,
num_heads, v_depth_per_head]`.
num_gqa_groups, v_depth_per_head]`.
bias: bias for the attention weights. This should be broadcastable to the
shape `[batch, num_heads, q_length, kv_length]` This can be used for
incorporating causal masks, padding masks, proximity bias, etc.
@@ -141,21 +145,31 @@ def dot_product_attention(query: Array,
batch_dim = 1 if transpose_batch_sequence else 0
assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], (
'q, k, v batch dims must match.')
assert query.shape[-2] == key.shape[-2] == value.shape[-2], ('q, k, v num_heads must match.')
sequence_dim = 0 if transpose_batch_sequence else 1
assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.'
assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'
assert key.shape[-2] == value.shape[-2], 'k, v num_heads must match.'
assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.'
# Casting logits and softmax computation for float32 for model stability.
if float32_logits:
query = query.astype(jnp.float32)
key = key.astype(jnp.float32)
# `attn_weights`: [batch, num_heads, q_length, kv_length]
# `attn_weights`: [batch, num_heads, groups, q_length, kv_length]
h_q, h_kv = query.shape[-2], key.shape[-2]
assert (h_q % h_kv == 0) and (h_q >= h_kv)
group_size = h_q // h_kv
grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))
if transpose_batch_sequence:
attn_weights = jnp.einsum('qbhd,kbhd->bhqk', query, key)
attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key)
else:
attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)
attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)
# reshape back to normal DPA shape for bias/softmax/dropout
b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
attn_weights_without_groups_shape = (b, h * g, q, k)
attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)
# Apply attention bias: masking, dropout, proximity bias, etc.
if bias is not None:
@@ -174,11 +188,13 @@ def dot_product_attention(query: Array,
multiplier = (keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype))
attn_weights = attn_weights * multiplier
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
# Take the linear combination of `value`.
if transpose_batch_sequence:
return jnp.einsum('bhqk,kbhd->qbhd', attn_weights, value)
return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape)
return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)
return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape)
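Note the round trip above: grouped logits of shape [b, h_kv, g, q, k] are flattened to the standard [b, num_heads, q, k] layout (valid because h_kv * g equals the number of query heads), so the existing bias, softmax, and dropout code runs unchanged, then restored for the value contraction. A shape check (illustrative):

b, s, h_q, h_kv, d = 2, 16, 12, 6, 64
g = h_q // h_kv
w = jnp.zeros((b, h_kv, g, s, s))
assert w.reshape(b, h_kv * g, s, s).shape == (b, h_q, s, s)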
class DenseGeneral(nn.Module):
@@ -235,7 +251,8 @@ class DenseGeneral(nn.Module):
if self.use_bias:
bias = nn_partitioning.param_with_axes('bias',
self.bias_init, (self.features,),
self.bias_init,
self.features,
self.dtype,
axes=self.bias_axes)
else:
@@ -332,6 +349,7 @@ class MultiHeadAttention(nn.Module):
Attributes:
num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
should be divisible by the number of heads.
num_gqa_groups: number of kv attention heads
head_dim: dimension of each head.
dtype: the dtype of the computation.
dropout_rate: dropout rate
@@ -340,9 +358,10 @@ class MultiHeadAttention(nn.Module):
numerical issues with bfloat16.
"""
num_heads: int
head_dim: int
transpose_batch_sequence: bool
num_heads: int = 8
num_gqa_groups: int | None = None
head_dim: int = 64
transpose_batch_sequence: bool = True
dtype: DType = jnp.float32
dropout_rate: float = 0.
kernel_init: Initializer = None
@@ -354,6 +373,8 @@ class MultiHeadAttention(nn.Module):
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_heads
super().__post_init__()
@nn.compact
@@ -393,18 +414,24 @@ class MultiHeadAttention(nn.Module):
Returns:
output of shape `[batch, length, q_features]`.
"""
projection = functools.partial(DenseGeneral,
q_projection = functools.partial(DenseGeneral,
axis=-1,
features=self.num_heads * self.head_dim,
kernel_axes=('embed', 'joined_kv'),
dtype=self.dtype)
kv_projection = functools.partial(DenseGeneral,
axis=-1,
features=self.num_gqa_groups * self.head_dim,
kernel_axes=('embed', 'joined_kv'),
dtype=self.dtype)
# NOTE: T5 does not explicitly rescale the attention logits by
# 1/sqrt(depth_kq)! This is folded into the initializers of the
# linear transformations, which is equivalent under Adafactor
depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
query_init = lambda *args: self.kernel_init(*args) / ( # pylint: disable=unnecessary-lambda-assignment
depth_scaling if self.scaled_query_init else 1.0)
query_init = lambda *args: self.kernel_init(*args) / (depth_scaling
if self.scaled_query_init else 1.0)
# Project inputs_q to multi-headed q/k/v
# dimensions are then [batch, length, num_heads, head_dim]
@@ -417,13 +444,17 @@ class MultiHeadAttention(nn.Module):
v_shape = (shape[0], shape[1] // 3)
q_kernel = query_init(key, q_shape, dtype)
k_kernel = self.kernel_init(key, k_shape, dtype) # pylint: disable=too-many-function-args
v_kernel = self.kernel_init(key, v_shape, dtype) # pylint: disable=too-many-function-args
k_kernel = self.kernel_init(key, k_shape, dtype)
v_kernel = self.kernel_init(key, v_shape, dtype)
return jnp.concatenate([q_kernel, k_kernel, v_kernel], axis=-1, dtype=dtype)
is_self_attn = (inputs_q is inputs_kv)
is_gqa = (self.num_heads != self.num_gqa_groups)
is_qkvpack = (is_self_attn and not is_gqa)
if self.fuse_qkv:
if inputs_q is inputs_kv:
if is_qkvpack:
qkv_proj = DenseGeneral(axis=-1,
features=self.num_heads * self.head_dim * 3,
kernel_axes=('embed', 'joined_kv'),
@@ -436,24 +467,24 @@ class MultiHeadAttention(nn.Module):
if self.scale_attn_logits:
query = query / depth_scaling
else:
query = projection(kernel_init=query_init, name='query')( \
query = q_projection(kernel_init=query_init, name='query')( \
(inputs_q / depth_scaling) if self.scale_attn_logits else inputs_q)
kv_proj = DenseGeneral(axis=-1,
features=self.num_heads * self.head_dim * 2,
features=self.num_gqa_groups * self.head_dim * 2,
kernel_axes=('embed', 'joined_kv'),
kernel_init=self.kernel_init,
name='kv',
dtype=self.dtype)(inputs_kv)
key, value = jnp.split(kv_proj, [self.num_heads * self.head_dim], axis=-1)
key, value = jnp.split(kv_proj, [self.num_gqa_groups * self.head_dim], axis=-1)
else:
query = projection(kernel_init=query_init, name='query')( \
query = q_projection(kernel_init=query_init, name='query')( \
(inputs_q / depth_scaling) if self.scale_attn_logits else inputs_q)
key = projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
value = projection(kernel_init=self.kernel_init, name='value')(inputs_kv)
key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv)
query = query.reshape((query.shape[0], query.shape[1], self.num_heads, self.head_dim))
key = key.reshape((key.shape[0], key.shape[1], self.num_heads, self.head_dim))
value = value.reshape((value.shape[0], value.shape[1], self.num_heads, self.head_dim))
query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
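# With GQA the projection widths differ: 'query' projects to
# num_heads * head_dim features, while 'key'/'value' project to
# num_gqa_groups * head_dim; e.g. num_heads=12, num_gqa_groups=6,
# head_dim=64 gives query [..., 768] and key/value [..., 384].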
if self.transpose_batch_sequence:
query = nn_partitioning.with_sharding_constraint(query,
@@ -476,7 +507,7 @@ class MultiHeadAttention(nn.Module):
# fusion optimization. This also enables the "scatter via one-hot
# broadcast" trick, which means we do a one-hot broadcast instead of a
# scatter/gather operations, resulting in a 3-4x speedup in practice.
swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3]) # pylint: disable=unnecessary-lambda-assignment
swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3])
cached_key = self.variable('cache', 'cached_key', jnp.zeros, swap_dims(key.shape),
key.dtype)
cached_value = self.variable('cache', 'cached_value', jnp.zeros, swap_dims(value.shape),
@@ -755,7 +786,8 @@ class RelativePositionBiases(nn.Module):
class EncoderLayer(nn.Module):
"""Transformer encoder layer."""
relative_embedding: nn.Module = None
num_heads: int = 8
num_attention_heads: int = 8
num_gqa_groups: int | None = None
head_dim: int = 64
dropout_rate: float = 0.1
transpose_batch_sequence: bool = True
@@ -773,6 +805,11 @@ class EncoderLayer(nn.Module):
fuse_qkv_params: bool = True
fuse_mlp_wi: bool = False
def __post_init__(self):
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
@nn.compact
def __call__(self, inputs, encoder_mask=None, deterministic=False):
# Relative position embedding as attention biases.
@@ -782,7 +819,7 @@ class EncoderLayer(nn.Module):
if self.relative_embedding is None:
rel_emb = RelativePositionBiases(num_buckets=32,
max_distance=128,
num_heads=self.num_heads,
num_heads=self.num_attention_heads,
dtype=self.dtype,
embedding_init=nn.initializers.variance_scaling(
1.0, 'fan_avg', 'uniform'),
@@ -807,7 +844,8 @@ class EncoderLayer(nn.Module):
x = inputs
# [batch, length, emb_dim] -> [batch, length, emb_dim]
x = MultiHeadAttention(num_heads=self.num_heads,
x = MultiHeadAttention(num_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
dtype=self.dtype,
head_dim=self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
@@ -868,7 +906,8 @@ class EncoderLayer(nn.Module):
class DecoderLayer(nn.Module):
"""Transformer decoder layer that attends to the encoder."""
relative_embedding: nn.Module = None
num_heads: int = 8
num_attention_heads: int = 8
num_gqa_groups: int | None = None
head_dim: int = 64
dropout_rate: float = 0.1
transpose_batch_sequence: bool = True
@@ -886,6 +925,11 @@ class DecoderLayer(nn.Module):
fuse_qkv_params: bool = True
fuse_mlp_wi: bool = False
def __post_init__(self):
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
@nn.compact
def __call__(self,
inputs,
@@ -903,7 +947,7 @@ class DecoderLayer(nn.Module):
if self.relative_embedding is None:
rel_emb = RelativePositionBiases(num_buckets=32,
max_distance=128,
num_heads=self.num_heads,
num_heads=self.num_attention_heads,
dtype=self.dtype,
embedding_init=nn.initializers.variance_scaling(
1.0, 'fan_avg', 'uniform'),
@@ -928,7 +972,8 @@ class DecoderLayer(nn.Module):
x = inputs
# Self-attention block
x = MultiHeadAttention(num_heads=self.num_heads,
x = MultiHeadAttention(num_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
dtype=self.dtype,
head_dim=self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
@@ -960,7 +1005,8 @@ class DecoderLayer(nn.Module):
if self.apply_residual_connection_post_layernorm:
residual = y
y = MultiHeadAttention(num_heads=self.num_heads,
y = MultiHeadAttention(num_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
dtype=self.dtype,
head_dim=self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
@@ -1012,6 +1058,9 @@ class DecoderLayer(nn.Module):
def make_causal_mask(batch, seqlen, dtype=jnp.uint8):
"""
Generate causal mask
"""
shape = (batch, seqlen)
idxs = jnp.broadcast_to(jnp.arange(shape[-1], dtype=jnp.int32), shape)
@@ -1022,6 +1071,9 @@ def make_causal_mask(batch, seqlen, dtype=jnp.uint8):
def make_self_mask(batch, seqlen, dtype=jnp.uint8):
"""
Generate attention mask
"""
shape = (batch, seqlen)
mask = jnp.ones((*shape, shape[-1]))
mask = jnp.expand_dims(mask, axis=-3)
@@ -1057,7 +1109,7 @@ def assert_allclose(
dtype = actual.dtype
# Determine tolerances
tols = dict()
tols = {}
if rtol is None or atol is None:
tols = dtype_tols(dtype)
if rtol is not None:
@@ -573,14 +573,14 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_QKV->data.dtype;
const auto QKV_type = input_QKV->data.dtype;
void *devPtrQKV = input_QKV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto stride = 0;
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = 2 * num_attn_heads * head_dim;
stride = typeToSize(QKV_type) * num_attn_heads * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
stride = 2 * head_dim;
stride = typeToSize(QKV_type) * head_dim;
}
void *devPtrQ = static_cast<void *>(devPtrQKV);
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
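The stride fix above replaces a hardcoded 2-byte element size (correct only for FP16/BF16) with typeToSize(QKV_type), so K/V pointer offsets are computed in bytes for any dtype; the kvpacked variants below additionally stride by num_gqa_groups instead of num_attn_heads. A byte-offset sketch of the 3HD layout (illustrative Python, with dtype_size standing in for typeToSize):

def qkv_byte_offsets(num_heads, head_dim, dtype_size):
    # In [..., 3, h, d] packing, Q, K, and V are h * d elements apart.
    stride = dtype_size * num_heads * head_dim
    return 0, stride, 2 * stride                  # byte offsets of Q, K, V

assert qkv_byte_offsets(16, 64, 2) == (0, 2048, 4096)   # 2-byte FP16/BF16
assert qkv_byte_offsets(16, 64, 1) == (0, 1024, 2048)   # a 1-byte dtype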
@@ -677,14 +677,15 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_QKV->data.dtype;
void *devPtrQKV = input_QKV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto stride = 0;
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = 2 * num_attn_heads * head_dim;
stride = typeToSize(QKV_type) * num_attn_heads * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
stride = 2 * head_dim;
stride = typeToSize(QKV_type) * head_dim;
}
void *devPtrQ = devPtrQKV;
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
@@ -712,7 +713,6 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
const auto qkv_type = input_QKV->data.dtype;
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_bwd_impl(batch, num_attn_heads, num_attn_heads,
@@ -723,7 +723,7 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea
devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlens, devPtrCuSeqlens,
get_cudnn_fe_dtype(qkv_type), workspace->data.dptr,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
if (workspace_size > 0) {
@@ -750,15 +750,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_Q->data.dtype;
const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr;
void *devPtrKV = input_KV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto stride = 0;
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = 2 * num_attn_heads * head_dim;
stride = typeToSize(QKV_type) * num_gqa_groups * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = 2 * head_dim;
stride = typeToSize(QKV_type) * head_dim;
}
void *devPtrK = devPtrKV;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
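Beyond the dtype-size fix, note the second correction in the `NVTE_HD_2HD` branch: under GQA the packed KV tensor carries `num_gqa_groups` heads, not `num_attn_heads`, so the V offset must be computed from the KV head count. A small sketch with assumed sizes:

```python
# Byte offset of V inside a packed KV buffer under GQA.
def v_byte_offset(layout_group, dtype_size, num_gqa_groups, head_dim):
    if layout_group == "NVTE_HD_2HD":   # KV: [..., 2, num_gqa_groups, head_dim]
        return dtype_size * num_gqa_groups * head_dim
    if layout_group == "NVTE_HD_H2D":   # KV: [..., num_gqa_groups, 2, head_dim]
        return dtype_size * head_dim
    raise ValueError(layout_group)

# 16 query heads sharing 4 KV groups, FP16: the old code would have skipped
# 2 * 16 * 64 = 2048 bytes; the correct V offset is 2 * 4 * 64 = 512.
print(v_byte_offset("NVTE_HD_2HD", 2, num_gqa_groups=4, head_dim=64))  # 512
```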
......@@ -860,15 +860,14 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr;
void *devPtrKV = input_KV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto stride = 0;
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = 2 * num_attn_heads * head_dim;
stride = typeToSize(QKV_type) * num_gqa_groups * head_dim;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = 2 * head_dim;
stride = typeToSize(QKV_type) * head_dim;
}
void *devPtrK = devPtrKV;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
......@@ -935,7 +934,7 @@ void fused_attn_arbitrary_seqlen_fwd(
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_Q->data.dtype;
const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr;
void *devPtrK = input_K->data.dptr;
void *devPtrV = input_V->data.dptr;
......
......@@ -1651,13 +1651,10 @@ class FusedAttnHelper:
def get_fused_attn_backend(self):
"""Get the fused attention kernel backend"""
return transformer_engine_jax.get_fused_attn_backend(jax_dtype_to_te_dtype(self.q_type),
jax_dtype_to_te_dtype(self.kv_type),
self.qkv_layout, self.attn_bias_type,
self.attn_mask_type,
self.dropout_probability,
self.num_heads_q, self.num_heads_kv,
self.max_seqlen_q, self.max_seqlen_kv,
return transformer_engine_jax.get_fused_attn_backend(
jax_dtype_to_te_dtype(self.q_type), jax_dtype_to_te_dtype(self.kv_type),
self.qkv_layout, self.attn_bias_type, self.attn_mask_type, self.dropout_probability,
self.num_heads_q, self.num_heads_kv, self.max_seqlen_q, self.max_seqlen_kv,
self.head_dim)
......@@ -1701,12 +1698,11 @@ class _FusedAttnRNGStateChecker:
return seed
def generate_cu_seqlen(mask):
def generate_cu_seqlen(actual_seqlen):
"""
Generate cumulative sequence lengths for a batch
"""
seqlen = jnp.sum(mask == 0, axis=(-1, -2), dtype=jnp.int32)
cu_seqlen = jnp.cumsum(seqlen)
cu_seqlen = jnp.cumsum(actual_seqlen)
cu_seqlen = jnp.hstack((0, cu_seqlen))
return cu_seqlen
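`generate_cu_seqlen` now takes the per-batch valid lengths directly (computed once before the primitive, per the commit message) instead of re-deriving them from the mask on every call. A worked example:

```python
import jax.numpy as jnp

def generate_cu_seqlen(actual_seqlen):
    # Prepend 0 so entries i and i+1 bracket sequence i in the flattened batch.
    return jnp.hstack((0, jnp.cumsum(actual_seqlen)))

# Three sequences with 3, 5, and 2 valid tokens:
print(generate_cu_seqlen(jnp.array([3, 5, 2])))  # [ 0  3  8 10]
```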
......@@ -1722,13 +1718,13 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
outer_primitive = None
@staticmethod
def abstract(qkv_aval, bias_aval, mask_or_cu_seqlen_aval, seed_aval, *, attn_bias_type,
def abstract(qkv_aval, bias_aval, seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training):
"""
Self fused attention fwd abstract
"""
# outer_primitve is squeezed_mask, inner_primitive is cu_seqlen
del mask_or_cu_seqlen_aval, scaling_factor, is_training
# outer_primitive is seqlen, inner_primitive is cu_seqlen
del seqlen_or_cu_seqlen_aval, scaling_factor, is_training
qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
*batch_shape, max_seqlen, nqkv, num_head, head_dim = qkv_aval.shape
assert nqkv == 3
......@@ -1781,19 +1777,20 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability,
attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
batch, num_head, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor,
dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
out = custom_caller(SelfFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
return out
@staticmethod
def impl(qkv, bias, squeezed_mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
def impl(qkv, bias, seqlen, seed, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
assert SelfFusedAttnFwdPrimitive.inner_primitive is not None
cu_seqlen = generate_cu_seqlen(squeezed_mask)
cu_seqlen = generate_cu_seqlen(seqlen)
output, softmax_aux, rng_state = SelfFusedAttnFwdPrimitive.inner_primitive.bind(
qkv,
......@@ -1859,10 +1856,9 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
register_primitive(SelfFusedAttnFwdPrimitive)
def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, squeezed_mask: jnp.ndarray,
seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool):
def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, seqlen: jnp.ndarray, seed: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
scaling_factor: float, dropout_probability: float, is_training: bool):
"""
Wrapper for TE self fused attention fwd
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
......@@ -1875,7 +1871,7 @@ def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, squeezed_mask: jnp.
bias = jnp.zeros(0, dtype=qkv.dtype)
return SelfFusedAttnFwdPrimitive.outer_primitive.bind(qkv,
bias,
squeezed_mask,
seqlen,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
......@@ -1896,14 +1892,14 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
@staticmethod
def abstract(qkv_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval, doutput_aval,
mask_or_cu_seqlen_aval, *, attn_bias_type, attn_mask_type, scaling_factor,
seqlen_or_cu_seqlen_aval, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
"""
Self fused attention bwd abstract
"""
del softmax_aux_aval, rng_state_aval
# outer_primitve is squeezed_mask, inner_primitive is cu_seqlen
del mask_or_cu_seqlen_aval, attn_bias_type, attn_mask_type
# outer_primitive is seqlen, inner_primitive is cu_seqlen
del seqlen_or_cu_seqlen_aval, attn_bias_type, attn_mask_type
del scaling_factor, dropout_probability, is_training
qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
......@@ -1934,19 +1930,20 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability,
attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
batch, num_head, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor,
dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
out = custom_caller(SelfFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
return out
@staticmethod
def impl(qkv, bias, softmax_aux, rng_state, output, doutput, squeezed_mask, attn_bias_type,
def impl(qkv, bias, softmax_aux, rng_state, output, doutput, seqlen, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training):
assert SelfFusedAttnBwdPrimitive.inner_primitive is not None
cu_seqlen = generate_cu_seqlen(squeezed_mask)
cu_seqlen = generate_cu_seqlen(seqlen)
dqkv, dbias = SelfFusedAttnBwdPrimitive.inner_primitive.bind(
qkv,
......@@ -2029,7 +2026,7 @@ register_primitive(SelfFusedAttnBwdPrimitive)
def self_fused_attn_bwd(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: jnp.ndarray,
rng_state: jnp.ndarray, output: jnp.ndarray, doutput: jnp.ndarray,
squeezed_mask: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
seqlen: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool):
"""
......@@ -2045,7 +2042,7 @@ def self_fused_attn_bwd(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: jnp.nd
rng_state,
output,
doutput,
squeezed_mask,
seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
......@@ -2064,13 +2061,13 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
outer_primitive = None
@staticmethod
def abstract(q_aval, kv_aval, bias_aval, q_mask_or_cu_seqlen_aval, kv_mask_or_cu_seqlen_aval,
seed_aval, *, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training):
def abstract(q_aval, kv_aval, bias_aval, q_seqlen_or_cu_seqlen_aval,
kv_seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
"""
Cross fused attention fwd abstract
"""
# outer_primitve is squeezed_mask, inner_primitive is cu_seqlen
# outer_primitive is seqlen, inner_primitive is cu_seqlen
del scaling_factor, is_training
q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
......@@ -2083,18 +2080,17 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
assert q_dtype == kv_dtype == bias_dtype
assert q_batch_shape == kv_batch_shape
assert q_num_head == kv_num_head
assert q_head_dim == kv_head_dim
assert nkv == 2
assert q_mask_or_cu_seqlen_aval.dtype == kv_mask_or_cu_seqlen_aval.dtype
assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
output_shape = q_aval.shape
output_dtype = q_dtype
backend = FusedAttnHelper(q_dtype, kv_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
attn_bias_type, attn_mask_type, dropout_probability,
q_num_head, kv_num_head,
q_max_seqlen, kv_max_seqlen, q_head_dim).get_fused_attn_backend()
attn_bias_type, attn_mask_type, dropout_probability, q_num_head,
kv_num_head, q_max_seqlen, kv_max_seqlen,
q_head_dim).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_aux_shape = (*q_batch_shape, q_num_head, q_max_seqlen, kv_max_seqlen)
......@@ -2128,7 +2124,7 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
*batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape
batch = reduce(operator.mul, batch_shape)
kv_max_seqlen = kv_aval.shape[-4]
kv_max_seqlen, kv_num_head = kv_aval.shape[-4], kv_aval.shape[-2]
operands = [q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed]
operand_shapes = map(lambda x: x.type.shape, operands)
......@@ -2139,7 +2135,7 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim,
batch, num_head, kv_num_head, q_max_seqlen, kv_max_seqlen, head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training)
......@@ -2148,12 +2144,12 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
return out
@staticmethod
def impl(q, kv, bias, q_squeezed_mask, kv_squeezed_mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
def impl(q, kv, bias, q_seqlen, kv_seqlen, seed, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
assert CrossFusedAttnFwdPrimitive.inner_primitive is not None
q_cu_seqlen = generate_cu_seqlen(q_squeezed_mask)
kv_cu_seqlen = generate_cu_seqlen(kv_squeezed_mask)
q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
output, softmax_aux, rng_state = CrossFusedAttnFwdPrimitive.inner_primitive.bind(
q,
......@@ -2224,9 +2220,8 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
register_primitive(CrossFusedAttnFwdPrimitive)
def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
q_squeezed_mask: jnp.ndarray, kv_squeezed_mask: jnp.ndarray,
seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, q_seqlen: jnp.ndarray,
kv_seqlen: jnp.ndarray, seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool):
"""
......@@ -2243,8 +2238,8 @@ def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
return CrossFusedAttnFwdPrimitive.outer_primitive.bind(q,
kv,
bias,
q_squeezed_mask,
kv_squeezed_mask,
q_seqlen,
kv_seqlen,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
......@@ -2296,7 +2291,7 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
*batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape
batch = reduce(operator.mul, batch_shape)
kv_max_seqlen = kv_aval.shape[-4]
kv_max_seqlen, kv_num_head = kv_aval.shape[-4], kv_aval.shape[-2]
operands = [q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen]
operand_shapes = map(lambda x: x.type.shape, operands)
......@@ -2310,7 +2305,7 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
# the dropout elements are encoded in the forward auxiliary tensor
# so seed is not needed in backward
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim,
batch, num_head, kv_num_head, q_max_seqlen, kv_max_seqlen, head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training)
......@@ -2319,13 +2314,12 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
return out
@staticmethod
def impl(q, kv, bias, softmax_aux, rng_state, output, doutput, q_squeezed_mask,
kv_squeezed_mask, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training):
def impl(q, kv, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen,
attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training):
assert CrossFusedAttnBwdPrimitive.inner_primitive is not None
q_cu_seqlen = generate_cu_seqlen(q_squeezed_mask)
kv_cu_seqlen = generate_cu_seqlen(kv_squeezed_mask)
q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
dq, dkv, dbias = CrossFusedAttnBwdPrimitive.inner_primitive.bind(
q,
......@@ -2417,10 +2411,9 @@ register_primitive(CrossFusedAttnBwdPrimitive)
def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray,
doutput: jnp.ndarray, q_squeezed_mask: jnp.ndarray,
kv_squeezed_mask: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool):
doutput: jnp.ndarray, q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
scaling_factor: float, dropout_probability: float, is_training: bool):
"""
Wrapper for TE cross fused attention bwd
Return the gradients of cross fused attention with packed kv input
......@@ -2435,8 +2428,8 @@ def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
rng_state,
output,
doutput,
q_squeezed_mask,
kv_squeezed_mask,
q_seqlen,
kv_seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
......
......@@ -82,12 +82,12 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch,
}
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
size_t batch, size_t num_head, size_t num_gqa_groups, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
return PackOpaque(CustomCallFusedAttnDescriptor{batch, num_head, q_max_seqlen, kv_max_seqlen,
head_dim, scaling_factor, dropout_probability,
bias_type, mask_type, dtype, is_training});
return PackOpaque(CustomCallFusedAttnDescriptor{
batch, num_head, num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim, scaling_factor,
dropout_probability, bias_type, mask_type, dtype, is_training});
}
void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream,
......@@ -745,8 +745,8 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
size_t head_dim) {
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_num_heads, kv_num_heads,
q_max_seqlen, kv_max_seqlen, head_dim);
mask_type, dropout_probability, q_num_heads, kv_num_heads, q_max_seqlen, kv_max_seqlen,
head_dim);
return backend;
}
......@@ -768,6 +768,7 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu
auto batch = descriptor.batch;
auto num_head = descriptor.num_head;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim;
......@@ -779,6 +780,9 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu
NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
"q_max_seqlen should be equal to kv_max_seqlen in the self attention.");
NVTE_CHECK(num_head == num_gqa_groups,
"num_head should be equal to num_gqa_groups in the qkvpacked attention");
auto dtype = descriptor.dtype;
auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
......@@ -799,10 +803,10 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu
// aux tensors
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, num_head, num_head,
q_max_seqlen, kv_max_seqlen, head_dim);
auto backend =
nvte_get_fused_attn_backend(static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype),
qkv_layout, bias_type, mask_type, dropout_probability, num_head,
num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
NVTETensorPack aux_output_tensors;
......@@ -853,6 +857,7 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq
auto batch = descriptor.batch;
auto num_head = descriptor.num_head;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim;
......@@ -864,6 +869,9 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq
NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
"q_max_seqlen should be equal to kv_max_seqlen in the self attention.");
NVTE_CHECK(num_head == num_gqa_groups,
"num_head should be equal to num_gqa_groups in the qkvpacked attention");
auto dtype = descriptor.dtype;
auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim};
auto output_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
......@@ -941,6 +949,7 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
auto batch = descriptor.batch;
auto num_head = descriptor.num_head;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim;
......@@ -951,7 +960,7 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
auto dtype = descriptor.dtype;
auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_head, head_dim};
auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
// input tensors
......@@ -976,10 +985,10 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, num_head, num_head,
q_max_seqlen, kv_max_seqlen, head_dim);
auto backend =
nvte_get_fused_attn_backend(static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype),
qkv_layout, bias_type, mask_type, dropout_probability, num_head,
num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
NVTETensorPack aux_output_tensors;
......@@ -1035,6 +1044,7 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
auto batch = descriptor.batch;
auto num_head = descriptor.num_head;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim;
......@@ -1045,7 +1055,7 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
auto dtype = descriptor.dtype;
auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_head, head_dim};
auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto output_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
......
......@@ -98,6 +98,7 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch,
struct CustomCallFusedAttnDescriptor {
size_t batch;
size_t num_head;
size_t num_gqa_groups;
size_t q_max_seqlen;
size_t kv_max_seqlen;
size_t head_dim;
......@@ -110,8 +111,8 @@ struct CustomCallFusedAttnDescriptor {
};
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
size_t batch, size_t num_head, size_t num_gqa_groups, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, DType dtype, bool is_training);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
......
......@@ -16,7 +16,6 @@ import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import dtypes
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap
......@@ -198,22 +197,31 @@ def core_attention(query: Array,
batch_dim = 1 if transpose_batch_sequence else 0
assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], (
'q, k, v batch dims must match.')
assert query.shape[-2] == key.shape[-2] == value.shape[-2], ('q, k, v num_heads must match.')
sequence_dim = 0 if transpose_batch_sequence else 1
assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.'
assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'
assert key.shape[-2] == value.shape[-2], 'k, v num_heads must match.'
assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.'
if float32_logits:
query = query.astype(jnp.float32)
key = key.astype(jnp.float32)
h_q, h_kv = query.shape[-2], key.shape[-2]
assert (h_q % h_kv == 0) and (h_q >= h_kv)
group_size = h_q // h_kv
grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))
if transpose_batch_sequence:
attn_weights = jnp.einsum('qbhd,kbhd->bhqk', query, key)
attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key)
else:
attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)
attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)
attn_weights = checkpoint_name(attn_weights, 'logits')
b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
attn_weights_without_groups_shape = (b, h * g, q, k)
attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)
attn_weights = _with_sharding_constraint(attn_weights,
(BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES))
......@@ -229,6 +237,8 @@ def core_attention(query: Array,
attn_weights = Softmax(softmax_type=softmax_type,
scale_factor=fused_scale_factor)(attn_weights, mask, bias).astype(dtype)
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
if not deterministic and dropout_rate > 0.:
keep_prob = 1.0 - dropout_rate
dropout_shape = list(attn_weights.shape)
......@@ -238,9 +248,9 @@ def core_attention(query: Array,
attn_weights = attn_weights * multiplier
if transpose_batch_sequence:
return jnp.einsum('bhqk,kbhd->qbhd', attn_weights, value)
return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape)
return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)
return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape)
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
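The unfused path above implements GQA by folding the query heads into `(h_kv, group_size)` so each KV head is contracted against its own group of query heads, then flattening the group axis back out for softmax and sharding. A shape check under assumed sizes:

```python
import jax
import jax.numpy as jnp

b, q_len, k_len, h_q, h_kv, d = 2, 8, 8, 16, 4, 64
group_size = h_q // h_kv   # 4 query heads share each KV head

rng = jax.random.PRNGKey(0)
query = jax.random.normal(rng, (b, q_len, h_q, d))
key = jax.random.normal(rng, (b, k_len, h_kv, d))

grouped_query = query.reshape((b, q_len, h_kv, group_size, d))
attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)
print(attn_weights.shape)                       # (2, 4, 4, 8, 8)
# Flattened for softmax/sharding: (b, h_kv * group_size, q_len, k_len)
print(attn_weights.reshape(b, h_kv * group_size, q_len, k_len).shape)
```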
......@@ -262,6 +272,14 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
The hidden dimension of each attention head.
num_heads : int
The number of attention heads
num_gqa_groups : int, default = `None`
Number of GQA groups. If `None`, it defaults to num_heads.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the queries.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. `num_gqa_groups = num_heads`.
dropout_rate : float, default = 0.0
Dropout probability for the dropout op during multi-head attention.
dropout_rng_name: str, default = 'dropout'
......@@ -321,6 +339,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
head_dim: int
num_heads: int
num_gqa_groups: int | None = None
dropout_rate: float = 0.
dropout_rng_name: str = 'dropout'
layernorm_type: str = "layernorm"
......@@ -342,6 +361,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_heads
super().__post_init__()
@nn.compact
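As the docstring above notes, `num_gqa_groups` interpolates between MQA and MHA. A quick sketch of the head-sharing arithmetic for `num_heads = 16`:

```python
num_heads = 16
for num_gqa_groups in (1, 4, 16):
    group_size = num_heads // num_gqa_groups
    kind = {1: 'MQA', num_heads: 'MHA'}.get(num_gqa_groups, 'GQA')
    print(f'{num_gqa_groups=}: {group_size} query head(s) per KV head ({kind})')
# num_gqa_groups=1: 16 query head(s) per KV head (MQA)
# num_gqa_groups=4: 4 query head(s) per KV head (GQA)
# num_gqa_groups=16: 1 query head(s) per KV head (MHA)
```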
......@@ -428,30 +449,22 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
"supported attn_mask_type = {'causal', 'padding'}")
is_self_attn = (inputs_q is inputs_kv)
is_gqa = (self.num_heads != self.num_gqa_groups)
is_qkvpack = (is_self_attn and not is_gqa)
qkv_layout = QKVLayout.BS3HD if is_self_attn else QKVLayout.BSHD_BS2HD
attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
canonicalize_dtype = dtypes.canonicalize_dtype(self.dtype)
q_seqlen = inputs_q.shape[0] if self.transpose_batch_sequence else inputs_q.shape[1]
kv_seqlen = inputs_kv.shape[0] if self.transpose_batch_sequence else inputs_kv.shape[1]
enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))
def _check_seqlen(seqlen):
return seqlen % 64 == 0
def _check_head_dim(head_dim):
return head_dim in [64, 128]
has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype, qkv_layout,
attn_bias_type, attn_mask_type,
self.dropout_rate,
self.num_heads, self.num_heads,
q_seqlen, kv_seqlen, self.head_dim)
self.dropout_rate, self.num_heads,
self.num_gqa_groups, q_seqlen,
kv_seqlen, self.head_dim)
use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \
canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \
_check_seqlen(q_seqlen) and _check_seqlen(kv_seqlen) and \
_check_head_dim(self.head_dim) and \
has_fused_attn_kernel and \
enable_fused_attn
......@@ -464,17 +477,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
f"but got {self.transpose_batch_sequence}, "
if not self.fuse_qkv:
reason += f"fuse_qkv=True is required but got {self.fuse_qkv}, "
if canonicalize_dtype not in [jnp.bfloat16, jnp.float16]:
reason += f"dtype in [BF16, FP16] is required " \
f"but got dtype={canonicalize_dtype}, "
if not _check_seqlen(q_seqlen):
reason += f"q_seqlen % 64 == 0 is required " \
f"but got {q_seqlen=}, "
if not _check_seqlen(kv_seqlen):
reason += f"kv_seqlen % 64 == 0 is required " \
f"but got {kv_seqlen=}, "
if not _check_head_dim(self.head_dim):
reason += f"head_dim should be 64 or 128 but got {self.head_dim}, "
if not has_fused_attn_kernel:
reason += "no fused attention kernel is available, "
......@@ -484,7 +486,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
residual = inputs_q
if self.fuse_qkv:
if is_self_attn:
if is_qkvpack:
qkv_proj, ln_out = LayerNormDenseGeneral(
enable_layernorm=not self.output_layernorm,
layernorm_type=self.layernorm_type,
......@@ -515,7 +517,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
axis=-1,
features=self.num_heads * self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm,
return_layernorm_output=(self.apply_residual_connection_post_layernorm
or is_self_attn),
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes=(W_FSDP_AXES, W_TP_AXES),
......@@ -525,8 +528,13 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
dtype=self.dtype,
kernel_init=query_init,
name='query')(inputs_q)
if is_self_attn:
assert ln_out is not None
inputs_kv = ln_out
kv_proj = DenseGeneral(axis=-1,
features=(2, self.num_heads * self.head_dim),
features=(2, self.num_gqa_groups * self.head_dim),
transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_init=kv_init,
......@@ -542,7 +550,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
kv_projection = functools.partial(
DenseGeneral,
axis=-1,
features=self.num_heads * self.head_dim,
features=self.num_gqa_groups * self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=(W_FSDP_AXES, W_TP_AXES),
use_bias=self.use_bias,
......@@ -583,9 +591,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
query = checkpoint_name(query, 'query_proj')
key = checkpoint_name(key, 'key_proj')
value = checkpoint_name(value, 'value_proj')
query = query.reshape((query.shape[0], query.shape[1], self.num_heads, self.head_dim))
key = key.reshape((key.shape[0], key.shape[1], self.num_heads, self.head_dim))
value = value.reshape((value.shape[0], value.shape[1], self.num_heads, self.head_dim))
query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
qkv_sharding_constraint = \
(SEQLEN_AXES, BATCH_AXES, HEAD_AXES, HIDDEN_AXES) \
if self.transpose_batch_sequence \
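With GQA, the K/V projections shrink to `num_gqa_groups * head_dim` features while queries keep the full `num_heads * head_dim`, which is where the parameter savings come from. A sketch of the resulting kernel and activation shapes (hidden size, batch, and sequence length assumed):

```python
hidden, num_heads, num_gqa_groups, head_dim = 1024, 16, 4, 64

print('q  kernel:', (hidden, num_heads * head_dim))          # (1024, 1024)
print('kv kernel:', (hidden, 2, num_gqa_groups * head_dim))  # (1024, 2, 256)

# After the reshapes above (batch=2, seqlen=128 assumed):
# query      -> (2, 128, 16, 64)
# key, value -> (2, 128,  4, 64)
```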
......@@ -650,7 +658,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
# ensure the old key never used
del dropout_rng
if is_self_attn:
if is_qkvpack:
qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim))
qkv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
HIDDEN_AXES)
......@@ -667,7 +675,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
else:
assert bias is None
query = query.reshape((*query.shape[:-1], self.num_heads, self.head_dim))
kv_proj = kv_proj.reshape((*kv_proj.shape[:-1], self.num_heads, self.head_dim))
kv_proj = kv_proj.reshape((*kv_proj.shape[:-1], self.num_gqa_groups, self.head_dim))
q_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES)
kv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
HIDDEN_AXES)
......@@ -865,6 +873,14 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
Intermediate size to which input samples are projected.
num_attention_heads: int, default = 8
Number of attention heads in the transformer layer.
num_gqa_groups : int, default = `None`
Number of GQA groups. If `None`, it defaults to num_attention_heads.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the queries.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
Indicate the type of layer normalization.
layernorm_epsilon: float, default = 1e-6
......@@ -961,6 +977,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
hidden_size: int = 512
mlp_hidden_size: int = 2048
num_attention_heads: int = 8
num_gqa_groups: int | None = None
layernorm_type: str = 'layernorm'
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
......@@ -995,6 +1012,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
if self.mlp_kernel_init is None:
self.mlp_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in',
'truncated_normal')
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
@nn.compact
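A hypothetical usage sketch for the flax layer above, enabling GQA with 2 KV groups shared by 8 query heads; the field names follow the dataclass, but the exact import path and the unlisted defaults are assumptions:

```python
# Hypothetical configuration sketch (import path and unlisted defaults assumed):
# 8 query heads sharing 2 KV groups, i.e. 4 query heads per KV head.
from transformer_engine.jax.flax import TransformerLayer

layer = TransformerLayer(
    hidden_size=512,
    mlp_hidden_size=2048,
    num_attention_heads=8,
    num_gqa_groups=2,
)
```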
......@@ -1091,6 +1110,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
num_heads=self.num_attention_heads,
dtype=self.dtype,
head_dim=head_dim,
num_gqa_groups=self.num_gqa_groups,
transpose_batch_sequence=self.transpose_batch_sequence,
dropout_rate=self.attention_dropout,
dropout_rng_name=self.dropout_rng_name,
......@@ -1141,6 +1161,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
num_heads=self.num_attention_heads,
dtype=self.dtype,
head_dim=head_dim,
num_gqa_groups=self.num_gqa_groups,
transpose_batch_sequence=self.transpose_batch_sequence,
dropout_rate=self.attention_dropout,
dropout_rng_name=self.dropout_rng_name,
......
......@@ -40,8 +40,8 @@ class QKVLayout(Enum):
def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, attn_mask_type,
dropout_probability, num_heads_q, num_heads_kv,
max_seqlen_q, max_seqlen_kv, head_dim):
dropout_probability, num_heads_q, num_heads_kv, max_seqlen_q,
max_seqlen_kv, head_dim):
"""
To check whether the fused attention kernel is available
"""
......@@ -83,10 +83,11 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.nda
seed: jnp.ndarray, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
squeezed_mask = mask[..., 0]
mask = jnp.logical_not(mask)
actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
output, softmax_aux, rng_state = self_fused_attn_fwd(qkv,
bias,
squeezed_mask,
actual_seqlen,
seed,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
......@@ -96,12 +97,12 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.nda
output = checkpoint_name(output, 'context')
softmax_aux = checkpoint_name(softmax_aux, 'context')
rng_state = checkpoint_name(rng_state, 'context')
return output, (qkv, bias, softmax_aux, rng_state, output, squeezed_mask)
return output, (qkv, bias, softmax_aux, rng_state, output, actual_seqlen)
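The forward rule now reduces the padding mask to per-batch valid lengths before binding the primitive (cheaper than passing the squeezed mask through, per the commit message). The TE convention here is that a mask entry of 1 marks a disallowed position, hence the `logical_not`. A worked example with one sequence of 3 valid tokens out of 4 (mask shape assumed `(batch, 1, q_len, kv_len)`):

```python
import jax.numpy as jnp

mask = jnp.array([[[[0, 0, 0, 1],
                    [0, 0, 0, 1],
                    [0, 0, 0, 1],
                    [1, 1, 1, 1]]]], dtype=jnp.uint8)   # 1 = masked out

allowed = jnp.logical_not(mask)
# Sum allowed entries down the q axis, then take head 0, kv position 0.
actual_seqlen = jnp.sum(allowed, axis=-2, dtype=jnp.int32)[..., 0, 0]
print(actual_seqlen)  # [3]
```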
def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, dz):
qkv, bias, softmax_aux, rng_state, output, squeezed_mask = ctx
qkv, bias, softmax_aux, rng_state, output, actual_seqlen = ctx
grad_qkv, grad_bias = self_fused_attn_bwd(qkv,
bias,
......@@ -109,7 +110,7 @@ def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dr
rng_state,
output,
dz,
squeezed_mask,
actual_seqlen,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
......@@ -159,14 +160,19 @@ def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask:
def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
q_squeezed_mask = mask[..., 0]
kv_squeezed_mask = mask[..., 0, :]
mask = jnp.logical_not(mask)
q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
else:
# When the mask is padding + causal, the actual seqlen cannot be read from a single row; use max to find it
kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))
output, softmax_aux, rng_state = cross_fused_attn_fwd(q,
kv,
bias,
q_squeezed_mask,
kv_squeezed_mask,
q_actual_seqlen,
kv_actual_seqlen,
seed,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
......@@ -174,12 +180,12 @@ def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mas
dropout_probability=dropout_probability,
is_training=is_training)
return output, (q, kv, bias, softmax_aux, rng_state, output, q_squeezed_mask, kv_squeezed_mask)
return output, (q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen)
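The branch on the mask type above matters: with a padding + causal mask, the last query row is entirely masked for padded sequences, so reading a single row would undercount; taking the max row sum recovers the true KV length. A worked example, same mask convention as above:

```python
import jax.numpy as jnp

# Padding + causal mask, 3 valid tokens out of 4 (1 = masked out).
mask = jnp.array([[[[0, 1, 1, 1],
                    [0, 0, 1, 1],
                    [0, 0, 0, 1],
                    [1, 1, 1, 1]]]], dtype=jnp.uint8)

allowed = jnp.logical_not(mask)
row_sums = jnp.sum(allowed, axis=-1, dtype=jnp.int32)   # [[[1 2 3 0]]]
print(jnp.max(row_sums, axis=(-1, -2)))                 # [3], the true kv length
```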
def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, dz):
q, kv, bias, softmax_aux, rng_state, output, q_squeezed_mask, kv_squeezed_mask = ctx
q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx
grad_q, grad_kv, grad_bias = cross_fused_attn_bwd(q,
kv,
......@@ -188,8 +194,8 @@ def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, d
rng_state,
output,
dz,
q_squeezed_mask,
kv_squeezed_mask,
q_actual_seqlen,
kv_actual_seqlen,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
......
......@@ -64,6 +64,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
head_dim: int = 64
num_heads: int = 16
num_gqa_groups: int | None = None
dropout_rate: float = 0.
dropout_rng_name: str = 'dropout'
layernorm_type: str = "layernorm"
......@@ -80,6 +81,11 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
scaled_query_init: bool = True
float32_logits: bool = False
def __post_init__(self):
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_heads
super().__post_init__()
def setup(self) -> None:
"""setup"""
super().setup()
......@@ -89,6 +95,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
dtype=self.dtype,
head_dim=self.head_dim,
num_heads=self.num_heads,
num_gqa_groups=self.num_gqa_groups,
dropout_rate=self.dropout_rate,
dropout_rng_name=self.dropout_rng_name,
layernorm_type=self.layernorm_type,
......@@ -131,6 +138,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
hidden_size: int = 512
mlp_hidden_size: int = 2048
num_attention_heads: int = 8
num_gqa_groups: int | None = None
layernorm_type: str = 'layernorm'
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
......@@ -156,6 +164,11 @@ class TransformerLayer(TransformerEngineBaseLayer):
scale_attn_logits: bool = False
scaled_query_init: bool = True
def __post_init__(self):
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
def setup(self) -> None:
"""setup"""
super().setup()
......@@ -186,6 +199,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
hidden_size=self.hidden_size,
mlp_hidden_size=self.mlp_hidden_size,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
layernorm_type=self.layernorm_type,
layernorm_epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
......