Unverified commit 73c9f421 authored by zlsh80826, committed by GitHub

Add FP16/BF16 fused_attention support with max_seqlen=512 (#175)



* Add fused attention unit tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Use NVTE_* enums
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Use NVTE_Mask_Type and remove FMHADescriptor
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Move common functions to utils
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Change namespace to fused_attn
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Move fused_attn_max_512_fwd_qkvpacked under the general APIs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add fused_attn_max_512_bwd_qkvpacked
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Move fused_attn_max_512_bwd_qkvpacked under the general APIs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove a redundant blank line
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix a potential bug in the cu_seqlen converter
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Reformat fused_attn_max_512
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refine the unfused attention warning message
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Rename to fused_attn_max_512
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove the deprecated header
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix flax import
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Rename to fused attn
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add attention-related masks
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add attn_mask_type and attn_bias_type
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refactor the JAX primitive API
  * Merge q_cu_seqlen and kv_cu_seqlen
  * Remove is_causal_masking
  * Replace seed with rng_state
  * Add is_training argument
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove dsoftmax from the custom call
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add None guards for bias and dropout_rng
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add version guard
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add is_fused_attn_kernel_available() to correctly dispatch the attention impl
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix the merge conflict
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Adjust the code style
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add the missing blank lines
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Change the order of FADescriptor members
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Enhance the readability of fused_attn_max_512.cu
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Generalize the input dimension unpacking
Signed-off-by: Reese Wang <rewang@nvidia.com>

* 16-bit fused attention requires cuDNN 8.9.1
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Update the fused attention support matrix
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Handle None type when sharding
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Change to the padding ratio
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Performance optimization for non-bias cases
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Revert the cudnn-frontend PRIVATE keyword, which was used for debugging
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Revert "Update fused attention support matrix"

This reverts commit 4effe67d0f08f733919a329ce5ab421958740f4a.
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Treat b * s as total_seqs to align with the ragged cases
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add FP16/BF16 max_seqlen <= 512 fused attention to the support matrix
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refine test_fused_attn.py
  * Replace the reference code with flax.linen
  * Remove unnecessary comments
  * Use AttnMaskType
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Unify the cuDNN compile version
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add dropout to the support matrix
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Slightly adjust the headers
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Typo fix: remove a redundant "either"
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Consolidate the fused attention requirements
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Replace cudnn_frontend::throw_if with NVTE_CHECK for better error line reporting
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Rename to fused_attn_fp16_bf16_max_seqlen_512 for better readability
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove CUDNN_FRONTEND_UNUSED
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add more annotations to the custom calls
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent c6a4a4e0
@@ -15,7 +15,7 @@ Prerequisites
 2. `CUDA 11.8 <https://developer.nvidia.com/cuda-downloads>`__
 3. |driver link|_ supporting CUDA 11.8 or later.
 4. `cuDNN 8.1 <https://developer.nvidia.com/cudnn>`__ or later.
-5. For FP8 fused attention, `CUDA 12.1 <https://developer.nvidia.com/cuda-downloads>`__ or later, |driver link|_ supporting CUDA 12.1 or later, and `cuDNN 8.9 <https://developer.nvidia.com/cudnn>`__ or later.
+5. For FP8/FP16/BF16 fused attention, `CUDA 12.1 <https://developer.nvidia.com/cuda-downloads>`__ or later, |driver link|_ supporting CUDA 12.1 or later, and `cuDNN 8.9.1 <https://developer.nvidia.com/cudnn>`__ or later.

 Transformer Engine in NGC Containers
......
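The JAX-level API added by this PR can be dispatched at runtime: is_fused_attn_kernel_available() reports whether the current device and cuDNN build support the fused kernel, so callers can fall back to an unfused implementation otherwise. A minimal dispatch sketch using the names this PR exports (the unfused fallback is a hypothetical placeholder):

    from transformer_engine.jax.fused_attn import (is_fused_attn_kernel_available,
                                                   self_fused_attn)

    def attention(qkv, bias, mask, dropout_rng, **kwargs):
        # qkv is packed as [batch, seqlen, 3, heads, head_dim]; in the mask,
        # 1 means "masked out" (see the inversion in the unit tests below).
        if is_fused_attn_kernel_available():
            return self_fused_attn(qkv, bias, mask, dropout_rng, **kwargs)
        return unfused_attention(qkv, bias, mask, dropout_rng, **kwargs)    # hypothetical fallback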
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Optional
import math
import jax
import jax.numpy as jnp
import numpy as np
import pytest
from flax.linen import combine_masks
from flax.linen import dot_product_attention
from flax.linen import make_attention_mask
from flax.linen import make_causal_mask
from jax import lax
from jax import nn as jax_nn
from jax import value_and_grad, jit
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType
from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn

# Type annotations
Array = jnp.ndarray

SELF_CASES = [(32, 512, 16, 64), (32, 128, 16, 64)]
CROSS_CASES = [(32, 128, 512, 16, 64)]
DTYPES = [jnp.bfloat16, jnp.float16]
PAD_RATIO = [0.3]
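# Each SELF_CASES tuple is (batch, seqlen, num_heads, head_dim); CROSS_CASES adds
# separate q/kv sequence lengths: (batch, q_seqlen, kv_seqlen, num_heads, head_dim).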


def make_decoder_mask(tokens: Array) -> Array:
causal_mask = make_causal_mask(tokens)
padding_mask = make_attention_mask(tokens > 0, tokens > 0)
return combine_masks(causal_mask, padding_mask)


def jax_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
attn_mask_type = kwargs['attn_mask_type']
if attn_mask_type == AttnMaskType.CAUSAL_MASK:
mask = make_decoder_mask(q_token)
else:
mask = make_attention_mask(q_token > 0, kv_token > 0)
query, key, value = jnp.split(qkv, [1, 2], axis=-3)
query = jnp.squeeze(query)
key = jnp.squeeze(key)
value = jnp.squeeze(value)
output = dot_product_attention(query,
key,
value,
bias=bias,
mask=mask,
dropout_rate=kwargs['dropout_probability'],
dropout_rng=dropout_rng,
dtype=qkv.dtype)
return output


def jax_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
assert q.dtype == kv.dtype
attn_mask_type = kwargs['attn_mask_type']
if attn_mask_type == AttnMaskType.CAUSAL_MASK:
raise NotImplementedError
mask = make_attention_mask(q_token > 0, kv_token > 0)
query = q
key, value = jnp.split(kv, [1], axis=-3)
key = jnp.squeeze(key)
value = jnp.squeeze(value)
output = dot_product_attention(query,
key,
value,
bias=None,
mask=mask,
dropout_rate=kwargs['dropout_probability'],
dropout_rng=dropout_rng,
dtype=q.dtype)
return output


def customcall_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
if kwargs['attn_mask_type'] == AttnMaskType.CAUSAL_MASK:
mask = make_decoder_mask(q_token)
else:
mask = make_attention_mask(q_token > 0, kv_token > 0)
    # Invert the mask: flax uses 1 for positions to attend to, while the fused
    # attention custom call expects True/1 for positions that are masked out.
    mask = (mask == 0)
return self_fused_attn(qkv, bias, mask, dropout_rng, **kwargs)
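# For example, with valid_len = 3 and s = 5, make_attention_mask gives a query row
# [1, 1, 1, 0, 0] (1 = attend); the inversion above turns it into [0, 0, 0, 1, 1]
# (1 = masked out), which is the convention the fused attention custom call expects.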


def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
assert q.dtype == kv.dtype
if kwargs['attn_mask_type'] == AttnMaskType.CAUSAL_MASK:
raise NotImplementedError
mask = make_attention_mask(q_token > 0, kv_token > 0)
    # Invert the mask (True/1 means masked out for the fused attention custom call)
    mask = (mask == 0)
return cross_fused_attn(q, kv, mask, dropout_rng, **kwargs)


@pytest.mark.skipif(not is_fused_attn_kernel_available(),
                    reason="Fused attention kernel is not supported.")
class TestSelfFusedAttnMax512():

    def set_input(self, b, s, h, d, dtype, attn_mask_type, pad_ratio):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
qkv_shape = (b, s, 3, h, d)
bias_shape = (1, h, s, s)
pad_len = int(s * pad_ratio)
self.valid_len = s - pad_len
min_val, max_val = -1, 1
self.qkv = jax.random.uniform(subkeys[0], qkv_shape, dtype, min_val, max_val)
self.bias = jax.random.uniform(subkeys[1], bias_shape, dtype, min_val, max_val)
self.q_token = jnp.concatenate((jnp.ones((b, self.valid_len)), jnp.zeros((b, pad_len))),
axis=-1)
self.kv_token = self.q_token
self.scaling_factor = 1. / math.sqrt(d)
self.dropout_probability = 0.
self.dropout_rng = jax.random.PRNGKey(0)
self.attn_bias_type = AttnBiasType.POST_SCALE_BIAS
# deterministic = not is_training
self.deterministic = False

    @pytest.mark.parametrize('b, s, h, d', SELF_CASES)
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('attn_mask_type',
[AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
@pytest.mark.parametrize('pad_ratio', PAD_RATIO)
def test_forward(self, b, s, h, d, dtype, attn_mask_type, pad_ratio):
self.set_input(b, s, h, d, dtype=dtype, attn_mask_type=attn_mask_type, pad_ratio=pad_ratio)
primitive_out = customcall_self_fused_attn(self.qkv,
self.bias,
self.q_token,
self.kv_token,
self.dropout_rng,
attn_bias_type=self.attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=self.scaling_factor,
dropout_probability=self.dropout_probability,
is_training=not self.deterministic)
reference_out = jax_self_fused_attn(self.qkv,
self.bias,
self.q_token,
self.kv_token,
self.dropout_rng,
attn_mask_type=attn_mask_type,
scaling_factor=self.scaling_factor,
dropout_probability=self.dropout_probability)
ref_valid, _ = jnp.split(reference_out, (self.valid_len,), axis=1)
pri_valid, pri_invalid = jnp.split(primitive_out, (self.valid_len,), axis=1)
np.testing.assert_allclose(jnp.asarray(pri_valid, np.float32),
jnp.asarray(ref_valid, np.float32),
rtol=1e-4,
atol=1e-2)
np.testing.assert_allclose(jnp.asarray(pri_invalid, jnp.float32),
jnp.zeros_like(pri_invalid, jnp.float32))

    @pytest.mark.parametrize('b, s, h, d', SELF_CASES)
@pytest.mark.parametrize('attn_mask_type',
[AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('pad_ratio', PAD_RATIO)
def test_forward_backward(self, b, s, h, d, dtype, attn_mask_type, pad_ratio):
self.set_input(b, s, h, d, dtype=dtype, attn_mask_type=attn_mask_type, pad_ratio=pad_ratio)

        def grad_func(fused_attn_max_512_func, *args, **kwargs):
            # The gradients are small, so amplify them with a multiplier
            gradient_multiplier = 1000 if dtype == jnp.bfloat16 else 10000
if attn_mask_type == AttnMaskType.CAUSAL_MASK:
gradient_multiplier = gradient_multiplier / 10
# Keep only valid result for the gradient
# fused_attn_max_512 output has shape (b, s, h, d)
valid_fused_attn_max_512_ret, _ = jnp.split(fused_attn_max_512_func(*args, **kwargs),
(self.valid_len,),
axis=1)
return (jnp.mean(valid_fused_attn_max_512_ret, dtype=jnp.float32) *
gradient_multiplier).astype(dtype)
kwargs = {
'attn_bias_type': self.attn_bias_type,
'attn_mask_type': attn_mask_type,
'scaling_factor': self.scaling_factor,
'dropout_probability': self.dropout_probability,
'is_training': not self.deterministic
}
        # Summing in FP16/BF16 may overflow, so accumulate in FP32
jitted_primitive = jit(
value_and_grad(
lambda qkv, bias, q_token, kv_token, dropout_rng: grad_func(
customcall_self_fused_attn, qkv, bias, q_token, kv_token, dropout_rng, **kwargs
), (0, 1)))
jitted_reference = jit(
value_and_grad(
lambda qkv, bias, q_token, kv_token, dropout_rng: grad_func(
jax_self_fused_attn, qkv, bias, q_token, kv_token, dropout_rng, **kwargs),
(0, 1)))
primitive_out, (primitive_dqkv,
primitive_dbeta) = jitted_primitive(self.qkv, self.bias, self.q_token,
self.kv_token, self.dropout_rng)
reference_out, (reference_dqkv,
reference_dbeta) = jitted_reference(self.qkv, self.bias, self.q_token,
self.kv_token, self.dropout_rng)
np.testing.assert_allclose(jnp.asarray(primitive_out, np.float32),
jnp.asarray(reference_out, np.float32),
rtol=1e-4,
atol=1e-5)
valid_primitive_dqkv, invalid_primitive_dqkv = jnp.split(primitive_dqkv, (self.valid_len,),
axis=1)
valid_reference_dqkv, invalid_reference_dqkv = jnp.split(reference_dqkv, (self.valid_len,),
axis=1)
# dQ
np.testing.assert_allclose(jnp.asarray(valid_primitive_dqkv[:, :, 0], np.float32),
jnp.asarray(valid_reference_dqkv[:, :, 0], np.float32),
rtol=1e-4,
atol=1e-5)
# dK
np.testing.assert_allclose(jnp.asarray(valid_primitive_dqkv[:, :, 1], np.float32),
jnp.asarray(valid_reference_dqkv[:, :, 1], np.float32),
rtol=1e-4,
atol=1e-5)
# dV
np.testing.assert_allclose(jnp.asarray(valid_primitive_dqkv[:, :, 2], np.float32),
jnp.asarray(valid_reference_dqkv[:, :, 2], np.float32),
rtol=1e-4,
atol=1e-5)
assert jnp.allclose(invalid_primitive_dqkv, invalid_reference_dqkv)
# Padded part should be 0s
assert jnp.allclose(invalid_primitive_dqkv, jnp.zeros_like(invalid_primitive_dqkv))
# dbeta valid part
np.testing.assert_allclose(
jnp.asarray(primitive_dbeta[:, :, :self.valid_len, :self.valid_len], np.float32),
jnp.asarray(reference_dbeta[:, :, :self.valid_len, :self.valid_len], np.float32),
rtol=1e-4,
atol=3e-5)
# dbeta padded part
np.testing.assert_allclose(
jnp.asarray(primitive_dbeta[:, :, self.valid_len:, self.valid_len:], np.float32),
jnp.asarray(reference_dbeta[:, :, self.valid_len:, self.valid_len:], np.float32))
assert jnp.allclose(primitive_dbeta[:, :, self.valid_len:, self.valid_len:],
jnp.zeros_like(primitive_dbeta[:, :, self.valid_len:, self.valid_len:]))


@pytest.mark.skipif(not is_fused_attn_kernel_available(),
                    reason="Fused attention kernel is not supported.")
class TestCrossFusedAttnMax512():

    def set_input(self, b, s_q, s_kv, h, d, dtype, attn_mask_type, pad_ratio):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
q_shape = (b, s_q, h, d)
kv_shape = (b, s_kv, 2, h, d)
q_pad_len = int(s_q * pad_ratio)
kv_pad_len = int(s_kv * pad_ratio)
self.q_valid_len = s_q - q_pad_len
self.kv_valid_len = s_kv - kv_pad_len
min_val, max_val = -1, 1
self.q = jax.random.uniform(subkeys[0], q_shape, dtype, min_val, max_val)
self.kv = jax.random.uniform(subkeys[1], kv_shape, dtype, min_val, max_val)
self.q_token = jnp.concatenate((jnp.ones((b, self.q_valid_len)), jnp.zeros((b, q_pad_len))),
axis=-1)
self.kv_token = jnp.concatenate((jnp.ones((b, self.kv_valid_len)), jnp.zeros(
(b, kv_pad_len))),
axis=-1)
self.scaling_factor = 1. / math.sqrt(d)
self.dropout_probability = 0.
self.dropout_rng = jax.random.PRNGKey(0)
self.attn_bias_type = AttnBiasType.NO_BIAS
# deterministic = not is_training
self.deterministic = False

    @pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_CASES)
@pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('pad_ratio', PAD_RATIO)
def test_forward(self, b, s_q, s_kv, h, d, dtype, attn_mask_type, pad_ratio):
self.set_input(b,
s_q,
s_kv,
h,
d,
dtype=dtype,
attn_mask_type=attn_mask_type,
pad_ratio=pad_ratio)
primitive_out = customcall_cross_fused_attn(self.q,
self.kv,
self.q_token,
self.kv_token,
self.dropout_rng,
attn_bias_type=self.attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=self.scaling_factor,
dropout_probability=self.dropout_probability,
is_training=not self.deterministic)
reference_out = jax_cross_fused_attn(self.q,
self.kv,
self.q_token,
self.kv_token,
self.dropout_rng,
attn_mask_type=attn_mask_type,
scaling_factor=self.scaling_factor,
dropout_probability=self.dropout_probability)
ref_valid, _ = jnp.split(reference_out, (self.q_valid_len,), axis=1)
pri_valid, pri_invalid = jnp.split(primitive_out, (self.q_valid_len,), axis=1)
np.testing.assert_allclose(jnp.asarray(pri_valid, np.float32),
jnp.asarray(ref_valid, np.float32),
rtol=1e-4,
atol=2e-3)
np.testing.assert_allclose(jnp.asarray(pri_invalid, jnp.float32),
jnp.zeros_like(pri_invalid, jnp.float32))

    @pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_CASES)
@pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('pad_ratio', PAD_RATIO)
def test_forward_backward(self, b, s_q, s_kv, h, d, dtype, attn_mask_type, pad_ratio):
self.set_input(b,
s_q,
s_kv,
h,
d,
dtype=dtype,
attn_mask_type=attn_mask_type,
pad_ratio=pad_ratio)

        def grad_func(fused_attn_max_512_func, *args, **kwargs):
            # The gradients are small, so amplify them with a multiplier
            gradient_multiplier = 10000
if attn_mask_type == AttnMaskType.CAUSAL_MASK:
gradient_multiplier = gradient_multiplier / 10
# Keep only valid result for the gradient
# fused_attn_max_512 output has shape (b, s_q, h, d)
valid_fused_attn_max_512_ret, _ = jnp.split(fused_attn_max_512_func(*args, **kwargs),
(self.q_valid_len,),
axis=1)
return (jnp.mean(valid_fused_attn_max_512_ret, dtype=jnp.float32) *
gradient_multiplier).astype(dtype)
kwargs = {
'attn_bias_type': self.attn_bias_type,
'attn_mask_type': attn_mask_type,
'scaling_factor': self.scaling_factor,
'dropout_probability': self.dropout_probability,
'is_training': not self.deterministic
}
        # Summing in FP16/BF16 may overflow, so accumulate in FP32
jitted_primitive = jit(
value_and_grad(
lambda q, kv, q_token, kv_token, dropout_rng: grad_func(
customcall_cross_fused_attn, q, kv, q_token, kv_token, dropout_rng, **kwargs),
(0, 1)))
jitted_reference = jit(
value_and_grad(
lambda q, kv, q_token, kv_token, dropout_rng: grad_func(
jax_cross_fused_attn, q, kv, q_token, kv_token, dropout_rng, **kwargs), (0, 1)))
primitive_out, (primitive_dq,
primitive_dkv) = jitted_primitive(self.q, self.kv, self.q_token,
self.kv_token, self.dropout_rng)
reference_out, (reference_dq,
reference_dkv) = jitted_reference(self.q, self.kv, self.q_token,
self.kv_token, self.dropout_rng)
np.testing.assert_allclose(jnp.asarray(primitive_out, np.float32),
jnp.asarray(reference_out, np.float32),
rtol=1e-4,
atol=1e-5)
valid_primitive_dq, invalid_primitive_dq = jnp.split(primitive_dq, (self.q_valid_len,),
axis=1)
valid_reference_dq, invalid_reference_dq = jnp.split(reference_dq, (self.q_valid_len,),
axis=1)
valid_primitive_dkv, invalid_primitive_dkv = jnp.split(primitive_dkv, (self.kv_valid_len,),
axis=1)
valid_reference_dkv, invalid_reference_dkv = jnp.split(reference_dkv, (self.kv_valid_len,),
axis=1)
# dQ
np.testing.assert_allclose(jnp.asarray(valid_primitive_dq, np.float32),
jnp.asarray(valid_reference_dq, np.float32),
rtol=1e-4,
atol=1e-5)
# dK
np.testing.assert_allclose(jnp.asarray(valid_primitive_dkv[:, :, 0], np.float32),
jnp.asarray(valid_reference_dkv[:, :, 0], np.float32),
rtol=1e-4,
atol=1e-5)
# dV
np.testing.assert_allclose(jnp.asarray(valid_primitive_dkv[:, :, 1], np.float32),
jnp.asarray(valid_reference_dkv[:, :, 1], np.float32),
rtol=1e-4,
atol=1e-5)
assert jnp.allclose(invalid_primitive_dq, invalid_reference_dq)
assert jnp.allclose(invalid_primitive_dkv, invalid_reference_dkv)
# Padded part should be 0s
assert jnp.allclose(invalid_primitive_dq, jnp.zeros_like(invalid_primitive_dq))
assert jnp.allclose(invalid_primitive_dkv, jnp.zeros_like(invalid_primitive_dkv))
@@ -54,7 +54,7 @@ def compare_frozen_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
                                err_msg=f"{key=} is not close")

-DATA_SHAPE = [(128, 32, 512), (512, 32, 512)]    # (seqlen, batch, emb_dim)
+DATA_SHAPE = [(32, 128, 1024), (32, 512, 1024)]    # (batch, seqlen, emb_dim)
 DTYPE = [jnp.float32, jnp.bfloat16]
 FP8_FORMATS = [Format.E4M3, Format.HYBRID]
@@ -68,6 +68,7 @@ _KEY_OF_FUSE_MLP_WI = "fuse_mlp_wi"
 _KEY_OF_LAYERNORM_TYPE = 'layernorm_type'
 _KEY_OF_ZERO_CENTERED_GAMMA = 'zero_centered_gamma'
 _KEY_OF_TRANSPOSE_BS = 'transpose_batch_sequence'
+_KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits"

 BASE_ATTRS = {_KEY_OF_TRANSPOSE_BS: True}
@@ -99,6 +100,13 @@ ATTRS = [{
     _KEY_OF_DROPOUT_RATE: 0.0,
     _KEY_OF_MLP_ACTIVATIONS: (('gelu', 'linear')),
     _KEY_OF_FUSE_MLP_WI: True
+}, {
+    _KEY_OF_TRANSPOSE_BS: False,
+    _KEY_OF_SCALE_ATTN_LOGITS: True,
+    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
+    _KEY_OF_DROPOUT_RATE: 0.0,
+    _KEY_OF_MLP_ACTIVATIONS: (('gelu', 'linear')),
+    _KEY_OF_FUSE_MLP_WI: True
 }]

 ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
@@ -129,12 +137,15 @@ class TestEncoderLayer:
         return ref, flax.core.frozen_dict.FrozenDict(unfreeze_target)

     def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
+        transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS]
+        batch, seqlen = data_shape[:2]
+        if transpose_batch_sequence:
+            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])
+        sequence_dim = 0 if transpose_batch_sequence else 1
         data_rng, init_rng, apply_rng = generate_test_rngs()
         inputs = (jax.random.normal(data_rng, data_shape, dtype),)
-        batch = data_shape[1] if _KEY_OF_TRANSPOSE_BS in attrs else data_shape[0]
-        seqlen = data_shape[0] if _KEY_OF_TRANSPOSE_BS in attrs else data_shape[1]
         padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
         ref_masks = (1 - padded_mask,)
         test_masks = (None, padded_mask)    # The second arg of Transformer is encoded tokens.
@@ -149,7 +160,6 @@ class TestEncoderLayer:
             else:
                 te_layer_attrs[k] = v
         ref_layer_cls = partial(RefEncoderLayer, dtype=dtype, **attrs)
-        sequence_dim = 0
         layer_cls = partial(TransformerLayer,
                             hidden_dropout_dims=(sequence_dim,),
                             layer_type=TransformerLayerType.ENCODER,
@@ -171,12 +181,15 @@ class TestEncoderLayer:
         del data_rng, init_rng, apply_rng

     def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
+        transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS]
+        batch, seqlen = data_shape[:2]
+        if transpose_batch_sequence:
+            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])
+        sequence_dim = 0 if transpose_batch_sequence else 1
         data_rng, init_rng, apply_rng = generate_test_rngs()
         inputs = (jax.random.normal(data_rng, data_shape, dtype),)
-        batch = data_shape[1] if _KEY_OF_TRANSPOSE_BS in attrs else data_shape[0]
-        seqlen = data_shape[0] if _KEY_OF_TRANSPOSE_BS in attrs else data_shape[1]
         padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
         ref_masks = (1 - padded_mask,)
         test_masks = (None, padded_mask)    # The second arg of Transformer is encoded tokens.
@@ -191,7 +204,6 @@ class TestEncoderLayer:
             else:
                 te_layer_attrs[k] = v
         ref_layer_cls = partial(RefEncoderLayer, dtype=dtype, **attrs)
-        sequence_dim = 0
         layer_cls = partial(TransformerLayer,
                             hidden_dropout_dims=(sequence_dim,),
                             layer_type=TransformerLayerType.ENCODER,
@@ -335,11 +347,13 @@ class TestDecoderLayer:
         return ref, flax.core.frozen_dict.FrozenDict(unfreeze_target)

     def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
-        data_rng, init_rng, apply_rng = generate_test_rngs()
-        batch = data_shape[1] if _KEY_OF_TRANSPOSE_BS in attrs else data_shape[0]
-        seqlen = data_shape[0] if _KEY_OF_TRANSPOSE_BS in attrs else data_shape[1]
+        transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS]
+        batch, seqlen = data_shape[:2]
+        if transpose_batch_sequence:
+            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])
+        sequence_dim = 0 if transpose_batch_sequence else 1
+
+        data_rng, init_rng, apply_rng = generate_test_rngs()
         inputs = (jax.random.normal(data_rng, data_shape,
                                     dtype), jax.random.normal(data_rng, data_shape, dtype))
@@ -358,7 +372,6 @@ class TestDecoderLayer:
             else:
                 te_layer_attrs[k] = v
         ref_layer_cls = partial(RefDecoderLayer, dtype=dtype, **attrs)
-        sequence_dim = 0
         layer_cls = partial(TransformerLayer,
                             hidden_dropout_dims=(sequence_dim,),
                             layer_type=TransformerLayerType.DECODER,
@@ -379,11 +392,13 @@ class TestDecoderLayer:
         del data_rng, init_rng, apply_rng

     def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
-        data_rng, init_rng, apply_rng = generate_test_rngs()
-        batch = data_shape[1] if _KEY_OF_TRANSPOSE_BS in attrs else data_shape[0]
-        seqlen = data_shape[0] if _KEY_OF_TRANSPOSE_BS in attrs else data_shape[1]
+        transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS]
+        batch, seqlen = data_shape[:2]
+        if transpose_batch_sequence:
+            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])
+        sequence_dim = 0 if transpose_batch_sequence else 1
+
+        data_rng, init_rng, apply_rng = generate_test_rngs()
         inputs = (jax.random.normal(data_rng, data_shape,
                                     dtype), jax.random.normal(data_rng, data_shape, dtype))
@@ -402,7 +417,6 @@ class TestDecoderLayer:
             else:
                 te_layer_attrs[k] = v
         ref_layer_cls = partial(RefDecoderLayer, dtype=dtype, **attrs)
-        sequence_dim = 0
         layer_cls = partial(TransformerLayer,
                             hidden_dropout_dims=(sequence_dim,),
                             layer_type=TransformerLayerType.DECODER,
......
@@ -12,6 +12,7 @@ list(APPEND transformer_engine_SOURCES
             transpose/transpose_fusion.cu
             transpose/multi_cast_transpose.cu
             activation/gelu.cu
+            fused_attn/fused_attn_fp16_bf16_max_seqlen_512.cu
             fused_attn/fused_attn_fp8.cu
             fused_attn/fused_attn.cpp
             fused_attn/utils.cu
......
@@ -7,6 +7,7 @@
 #include "transformer_engine/fused_attn.h"

 #include "../common.h"
 #include "utils.h"
+#include "fused_attn_fp16_bf16_max_seqlen_512.h"
 #include "fused_attn_fp8.h"

 // NVTE fused attention FWD FP8 with packed QKV
@@ -26,6 +27,7 @@ void nvte_fused_attn_fwd_qkvpacked(
             cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
   using namespace transformer_engine;
+
   const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor*>(cu_seqlens);
   const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(rng_state);
   const Tensor *input_QKV = reinterpret_cast<const Tensor*>(QKV);
@@ -35,15 +37,17 @@ void nvte_fused_attn_fwd_qkvpacked(
   Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);

   // QKV shape is [total_seqs, 3, h, d]
+  auto ndim = input_QKV->data.shape.size();
   size_t b = input_cu_seqlens->data.shape[0] - 1;
-  size_t h = input_QKV->data.shape[2];
-  size_t d = input_QKV->data.shape[3];
+  size_t h = input_QKV->data.shape[ndim - 2];
+  size_t d = input_QKV->data.shape[ndim - 1];
+
+  auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();

   const DType QKV_type = input_QKV->data.dtype;
   if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
                   && (max_seqlen <= 512)) {
 #if (CUDNN_VERSION >= 8900)
-    auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
     // FP8 API doesn't use input_Bias, bias_type or attn_mask_type
     fused_attn_fwd_fp8_qkvpacked(
             b, max_seqlen, h, d,
@@ -58,7 +62,31 @@ void nvte_fused_attn_fwd_qkvpacked(
 #endif
   } else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
                   && (max_seqlen <= 512)) {
-    NVTE_ERROR("TBD: No support for BF16/FP16 fused attention currently. \n");
+#if (CUDNN_VERSION >= 8901)
+    fused_attn_max_512_fwd_qkvpacked(
+        b,
+        max_seqlen,
+        h,
+        d,
+        is_training,
+        attn_scale,
+        dropout,
+        qkv_layout,
+        bias_type,
+        attn_mask_type,
+        input_QKV,
+        input_Bias,
+        output_O,
+        Aux_Output_Tensors,
+        input_cu_seqlens,
+        input_rng_state,
+        wkspace,
+        stream,
+        handle);
+#else
+    NVTE_ERROR(
+        "cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n");
+#endif
   } else if (max_seqlen > 512) {
     NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
   } else {
@@ -84,6 +112,7 @@ void nvte_fused_attn_bwd_qkvpacked(
             cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
   using namespace transformer_engine;
+
   const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor*>(cu_seqlens);
   const Tensor *input_QKV = reinterpret_cast<const Tensor*>(QKV);
   const Tensor *input_O = reinterpret_cast<const Tensor*>(O);
@@ -95,9 +124,12 @@ void nvte_fused_attn_bwd_qkvpacked(
   Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);

   // QKV shape is [total_seqs, 3, h, d]
+  auto ndim = input_QKV->data.shape.size();
   size_t b = input_cu_seqlens->data.shape[0] - 1;
-  size_t h = input_QKV->data.shape[2];
-  size_t d = input_QKV->data.shape[3];
+  size_t h = input_QKV->data.shape[ndim - 2];
+  size_t d = input_QKV->data.shape[ndim - 1];
+
+  auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();

   const DType QKV_type = input_QKV->data.dtype;
   if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
@@ -107,7 +139,7 @@ void nvte_fused_attn_bwd_qkvpacked(
     const Tensor *input_M = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[0]);
     const Tensor *input_ZInv = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
     const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
-    auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
+
     // FP8 API doesn't use input_dBias, bias_type or attn_mask_type
     fused_attn_bwd_fp8_qkvpacked(
             b, max_seqlen, h, d,
@@ -124,7 +156,30 @@ void nvte_fused_attn_bwd_qkvpacked(
 #endif
   } else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
                   && (max_seqlen <= 512)) {
-    NVTE_ERROR("TBD: No support for BF16/FP16 fused attention currently. \n");
+#if (CUDNN_VERSION >= 8901)
+    fused_attn_max_512_bwd_qkvpacked(
+        b,
+        max_seqlen,
+        h,
+        d,
+        attn_scale,
+        dropout,
+        qkv_layout,
+        bias_type,
+        attn_mask_type,
+        input_QKV,
+        input_dO,
+        Aux_CTX_Tensors,
+        output_dQKV,
+        output_dBias,
+        input_cu_seqlens,
+        wkspace,
+        stream,
+        handle);
+#else
+    NVTE_ERROR(
+        "cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n");
+#endif
   } else if (max_seqlen > 512) {
     NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
   } else {
@@ -161,9 +216,13 @@ void nvte_fused_attn_fwd_kvpacked(
   Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);

   // Q shape is [total_seqs, h, d]
+  // KV shape is [total_seqs, h, d]
+  auto ndim = input_Q->data.shape.size();
   size_t b = input_cu_seqlens_q->data.shape[0] - 1;
-  size_t h = input_Q->data.shape[1];
-  size_t d = input_Q->data.shape[2];
+  size_t h = input_Q->data.shape[ndim - 2];
+  size_t d = input_Q->data.shape[ndim - 1];
+
+  auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();

   const DType QKV_type = input_Q->data.dtype;
   if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
@@ -171,7 +230,34 @@ void nvte_fused_attn_fwd_kvpacked(
     NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
   } else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
                   && (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
-    NVTE_ERROR("TBD: No support for BF16/FP16 fused attention currently. \n");
+#if (CUDNN_VERSION >= 8901)
+    fused_attn_max_512_fwd_kvpacked(
+        b,
+        max_seqlen_q,
+        max_seqlen_kv,
+        h,
+        d,
+        is_training,
+        attn_scale,
+        dropout,
+        qkv_layout,
+        bias_type,
+        attn_mask_type,
+        input_Q,
+        input_KV,
+        input_Bias,
+        output_O,
+        Aux_Output_Tensors,
+        input_cu_seqlens_q,
+        input_cu_seqlens_kv,
+        input_rng_state,
+        wkspace,
+        stream,
+        handle);
+#else
+    NVTE_ERROR(
+        "cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n");
+#endif
   } else if ((max_seqlen_q > 512) || (max_seqlen_kv > 512)) {
     NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
   } else {
@@ -214,16 +300,48 @@ void nvte_fused_attn_bwd_kvpacked(
   Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);

   // Q shape is [total_seqs, h, d]
+  // KV shape is [total_seqs, h, d]
+  auto ndim = input_Q->data.shape.size();
   size_t b = input_cu_seqlens_q->data.shape[0] - 1;
-  size_t h = input_Q->data.shape[1];
-  size_t d = input_Q->data.shape[2];
+  size_t h = input_Q->data.shape[ndim - 2];
+  size_t d = input_Q->data.shape[ndim - 1];
+
+  auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();

   const DType QKV_type = input_Q->data.dtype;
   if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
                   && (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
     NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
   } else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
                   && (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
-    NVTE_ERROR("TBD: No support for BF16/FP16 fused attention currently. \n");
+#if (CUDNN_VERSION >= 8901)
+    fused_attn_max_512_bwd_kvpacked(
+        b,
+        max_seqlen_q,
+        max_seqlen_kv,
+        h,
+        d,
+        attn_scale,
+        dropout,
+        qkv_layout,
+        bias_type,
+        attn_mask_type,
+        input_Q,
+        input_KV,
+        input_dO,
+        Aux_CTX_Tensors,
+        output_dQ,
+        output_dKV,
+        output_dBias,
+        input_cu_seqlens_q,
+        input_cu_seqlens_kv,
+        wkspace,
+        stream,
+        handle);
+#else
+    NVTE_ERROR(
+        "cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n");
+#endif
   } else if ((max_seqlen_q > 512) || (max_seqlen_kv > 512)) {
     NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
   } else {
......
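The generalized dimension unpacking above leans on the ragged packing convention used throughout these APIs: the packed QKV tensor is [total_seqs, 3, h, d] and cu_seqlens holds b + 1 prefix sums of the per-sequence lengths, so the batch size is recoverable as shape[0] - 1. A small illustration of that convention (values invented for the example):

    import numpy as np

    # Three sequences of lengths 2, 5 and 3, packed along the first axis.
    seqlens = np.array([2, 5, 3])
    cu_seqlens = np.concatenate([[0], np.cumsum(seqlens)])    # [0, 2, 7, 10]
    b = len(cu_seqlens) - 1             # 3, matches input_cu_seqlens->data.shape[0] - 1
    total_seqs = int(cu_seqlens[-1])    # 10 rows in the packed tensor
    qkv = np.zeros((total_seqs, 3, 16, 64))    # [total_seqs, 3, h, d]
    h, d = qkv.shape[-2], qkv.shape[-1]        # shape[ndim - 2], shape[ndim - 1]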
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "fused_attn_fp16_bf16_max_seqlen_512.h"
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cudnn_frontend.h>
#include <map>
#include <vector>
#include "../common.h"
#include "utils.h"
#if (CUDNN_VERSION >= 8901)
#define Q_ID 1
#define K_ID 2
#define V_ID 3
#define O_ID 4
#define S_ID 5
#define B_ID 6
#define D_CONST_ID 7
#define S_CONST_ID 8
#define Q_SEQLEN_ID 9
#define K_SEQLEN_ID 10
#define dQ_ID 11
#define dK_ID 12
#define dV_ID 13
#define dO_ID 14
#define MASK_VAL_ID 15
#define dS_ID 16
#define dBias_ID 17
#define VIRTUAL_ID 20
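// Every tensor in a cuDNN frontend graph must carry a unique UID; the IDs above name
// the real inputs/outputs, while intermediate (virtual) tensors are numbered upward
// from VIRTUAL_ID so the two ranges never collide.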
namespace transformer_engine {
namespace fused_attn {

static void createScale(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
// NOLINTNEXTLINE(runtime/references)
std::vector<cudnn_frontend::Operation> &ops) {
// scale
int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1};
int64_t k_dim[4] = {b, h, d, s_kv};
int64_t k_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, k_stride, layout,
NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose);
auto scaleTensor =
tensor_create(tensorType, S_CONST_ID, scale_dim, scale_stride, false, true); // is by value
auto kTensor = tensor_create(tensorType, K_ID, k_dim, k_stride, false, false);
auto afterScaleKTensor =
tensor_create(tensorType, VIRTUAL_ID, k_dim, k_stride, true, false); // is virtual
// Define the scale descriptor
auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a Scale Node.
auto scale_op = binary_pw_op_create(kTensor, scaleTensor, afterScaleKTensor, scaleDesc);
ops.push_back(std::move(scale_op));
}

static cudnn_frontend::Tensor createBMM1(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
bool zero_s,
// NOLINTNEXTLINE(runtime/references)
std::vector<cudnn_frontend::Operation> &ops) {
// Creates the necessary tensor descriptors
int64_t q_dim[4] = {b, h, s_q, d};
int64_t q_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
int64_t k_dim[4] = {b, h, d, s_kv};
int64_t k_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, k_stride, layout,
NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose);
int64_t p_dim[4] = {b, h, s_q, s_kv};
int64_t p_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, p_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix);
int64_t seqlen_dim[4] = {b, 1, 1, 1};
int64_t seqlen_stride[4] = {1, 1, 1, 1};
auto qTensor = tensor_create(tensorType, Q_ID, q_dim, q_stride, false, false);
auto afterScaleKTensor =
tensor_create(tensorType, VIRTUAL_ID, k_dim, k_stride, true, false); // is virtual
// first GEMM output
auto pTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 1, p_dim, p_stride, true,
false); // is virtual
auto seqlenQTensor =
tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
auto seqlenKTensor =
tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
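  // The seqlen tensors feed the m/n override descriptors on the matmul below, so each
  // batch entry runs the GEMM over its actual sequence lengths rather than the padded
  // s_q/s_kv extents.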
// Define the matmul 1 desc
// set padding value optionally to 0 for writing zeros to S tensor (if not set, old behaviour)
auto matmul_1_Desc =
cudnn_frontend::MatMulDescBuilder().setComputeType(CUDNN_DATA_FLOAT).build();
if (zero_s) {
matmul_1_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setPaddingValue(0.0f)
.build();
}
// Create a matmul 1 Node
auto matmul_op1 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(qTensor)
.setbMatDesc(afterScaleKTensor)
.setcMatDesc(pTensor)
.setmOverrideDesc(seqlenQTensor)
.setnOverrideDesc(seqlenKTensor)
.setmatmulDesc(matmul_1_Desc)
.build();
ops.push_back(std::move(matmul_op1));
return pTensor;
}

static cudnn_frontend::Tensor createBias(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
// NOLINTNEXTLINE(runtime/references)
std::vector<cudnn_frontend::Operation> &ops,
cudnn_frontend::Tensor const &prevBlockOutputTensor) {
NVTE_CHECK(ops.size() != 0, "Bias op constructed incorrectly as the first one.");
int64_t b_dim[4] = {1, h, s_q, s_kv};
int64_t b_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t afterBias_dim[4] = {b, h, s_q, s_kv};
int64_t afterBias_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, afterBias_stride, layout,
NVTE_QKV_Matrix::NVTE_S_Matrix);
// bias
auto bTensor = tensor_create(tensorType, B_ID, b_dim, b_stride, false, false);
// output
auto afterBiasTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 50, afterBias_dim,
afterBias_stride, true, false); // is virtual
// Define the bias descriptor
auto biasDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_ADD);
// Create a Bias Node.
auto bias_op = binary_pw_op_create(prevBlockOutputTensor, bTensor, afterBiasTensor, biasDesc);
ops.push_back(std::move(bias_op));
return afterBiasTensor;
}

static cudnn_frontend::Tensor createMask(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout, NVTE_Mask_Type mask_type,
cudnnDataType_t tensorType,
// NOLINTNEXTLINE(runtime/references)
std::vector<cudnn_frontend::Operation> &ops,
cudnn_frontend::Tensor const &prevBlockOutputTensor,
bool is_bprop) {
NVTE_CHECK(ops.size() != 0, "Padding mask constructed incorrectly as the first one.");
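  // Mask construction sketch: generate row/column indices over the S matrix, compare
  // them against the per-batch q/kv sequence lengths to get the padding mask, AND it
  // with (row >= col) for the causal case, then binary-select between the BMM1 output
  // and the fill value in MASK_VAL_ID.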
// subtraction output
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t seqlen_dim[4] = {b, 1, 1, 1};
int64_t seqlen_stride[4] = {1, 1, 1, 1};
int64_t maskVal_dim[4] = {1, 1, 1, 1};
int64_t maskVal_stride[4] = {1, 1, 1, 1};
  // mask value written to the masked positions
auto maskValTensor = tensor_create(CUDNN_DATA_FLOAT, MASK_VAL_ID, maskVal_dim, maskVal_stride,
false, true); // is by value
auto seqlenQTensor =
tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
auto seqlenKTensor =
tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
// gen index row output
auto rowIndexTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 100, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// gen index column output
auto columnIndexTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 101, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// less than row output
auto lessThanRowTensor =
tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 102, afterBMM1_dim, afterBMM1_stride, true,
false); // is virtual
// less than column output
auto lessThanColTensor = tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 103, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// padding mask (lessthanRow && lessthanCol)
auto paddingMaskTensor = tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 104, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// row >= col check for causal mask
auto rowGreaterColTensor = tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 105, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// create causal mask (padding && row >= col)
auto causalMaskTensor = tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 106, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// output after masking
int64_t maskOutputTensor_id = VIRTUAL_ID + 107;
  bool maskOutputTensor_virtual = true;
cudnnDataType_t maskOutputTensor_dataType = CUDNN_DATA_FLOAT;
auto maskOutputTensor_reorderType =
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_NONE;
if (is_bprop) {
maskOutputTensor_id = dS_ID;
maskOutputTensor_virtual = false;
maskOutputTensor_dataType = tensorType;
maskOutputTensor_reorderType =
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16;
}
auto maskOutputTensor =
cudnn_frontend::TensorBuilder()
.setDim(4, afterBMM1_dim)
.setStride(4, afterBMM1_stride)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setByValue(false)
.setDataType(maskOutputTensor_dataType)
.setVirtual(maskOutputTensor_virtual)
.setId(maskOutputTensor_id)
.setReorderType(maskOutputTensor_reorderType)
.build();
// Define the gen index for row descriptor
auto genIndexRowDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setAxis(2)
.setComputeType(CUDNN_DATA_FLOAT)
.build();
// Create a gen index Node.
auto genIndexRow_op =
unary_pw_op_create(prevBlockOutputTensor, rowIndexTensor, genIndexRowDesc);
  // Define the gen index for column descriptor
auto genIndexColumnDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setAxis(3)
.setComputeType(CUDNN_DATA_FLOAT)
.build();
// Create a gen index Node.
auto genIndexColumn_op =
unary_pw_op_create(prevBlockOutputTensor, columnIndexTensor, genIndexColumnDesc);
// Define the less than comparison for row descriptor
auto lessThanRowDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_LT);
// Create a less than comparison for row Node.
auto lessThanRow_op =
binary_pw_op_create(rowIndexTensor, seqlenQTensor, lessThanRowTensor, lessThanRowDesc);
// Define the less than comparison for column descriptor
auto lessThanColDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_LT);
// Create a less than comparison for col Node.
auto lessThanCol_op =
binary_pw_op_create(columnIndexTensor, seqlenKTensor, lessThanColTensor, lessThanColDesc);
  // Define the logical AND descriptor for combining the row and column predicates
auto paddingMaskAndDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_LOGICAL_AND);
// Create a and node for combining lessThanRow and lessThanCol
auto paddingMaskAnd_op = binary_pw_op_create(lessThanRowTensor, lessThanColTensor,
paddingMaskTensor, paddingMaskAndDesc);
// Define the greater than equal to comparison descriptor
auto rowGreaterColDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_CMP_GE);
// Create a greater than equal to Node.
auto rowGreaterCol_op = binary_pw_op_create(rowIndexTensor, columnIndexTensor,
rowGreaterColTensor, rowGreaterColDesc);
// Define the and to create causal mask descriptor
auto causalMaskAndDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_LOGICAL_AND);
// Create a causal Mask Node.
auto causalMaskAnd_op = binary_pw_op_create(paddingMaskTensor, rowGreaterColTensor,
causalMaskTensor, causalMaskAndDesc);
/////////////////// Apply the mask //////////////////////////
auto maskTensor = (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
? std::move(causalMaskTensor)
: std::move(paddingMaskTensor);
// Define the binary select to perform masking descriptor
auto maskDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_BINARY_SELECT);
// Create a binary select Node.
auto mask_op = ternary_pw_op_create(prevBlockOutputTensor, maskValTensor, maskTensor,
maskOutputTensor, maskDesc);
ops.push_back(std::move(genIndexRow_op));
ops.push_back(std::move(genIndexColumn_op));
ops.push_back(std::move(lessThanRow_op));
ops.push_back(std::move(lessThanCol_op));
ops.push_back(std::move(paddingMaskAnd_op));
if (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) {
ops.push_back(std::move(rowGreaterCol_op));
ops.push_back(std::move(causalMaskAnd_op));
}
ops.push_back(std::move(mask_op));
return maskOutputTensor;
}

static cudnn_frontend::Tensor createSoftmaxForward(
int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, NVTE_QKV_Layout layout,
bool enable_dropout, bool softmax_output_virtual, cudnnDataType_t tensorType,
// NOLINTNEXTLINE(runtime/references)
std::vector<cudnn_frontend::Operation> &ops,
cudnn_frontend::Tensor const &prevBlockOutputTensor) {
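  // Numerically stable softmax, built from reduction and pointwise nodes:
  //   y = exp(x - max(x)) / sum(exp(x - max(x)))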
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t afterReduction_dim[4] = {b, h, s_q, 1};
int64_t afterReduction_stride[4] = {h * s_q, s_q, 1, 1};
cudnnDataType_t softmaxOutputType =
(enable_dropout || softmax_output_virtual) ? CUDNN_DATA_FLOAT : tensorType;
uint64_t softmaxOutputName = softmax_output_virtual ? VIRTUAL_ID + 154 : S_ID;
// max (x)
auto afterMaxReductionTensor =
tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 150, afterReduction_dim, afterReduction_stride,
true, false); // is virtual
// x - max(x)
auto afterSubtractionTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 151, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// e^(x - max(x))
  auto afterExponentTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 152, afterBMM1_dim,
                                           afterBMM1_stride, true, false);  // is virtual
// sum (e^(x - max(x)))
auto afterAddReductionTensor =
tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 153, afterReduction_dim, afterReduction_stride,
true, false); // is virtual
// divide (e/ sum(e))
auto reorder_type =
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16;
auto afterDivisionTensor =
cudnn_frontend::TensorBuilder()
.setDim(4, afterBMM1_dim)
.setStride(4, afterBMM1_stride)
.setId(softmaxOutputName)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(softmaxOutputType)
.setVirtual(softmax_output_virtual)
.setByValue(false)
.setReorderType(reorder_type)
.build();
// Define the reduction descriptor
auto reductionMaxDesc = cudnn_frontend::ReductionDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_MAX)
.build();
// Create a reduction max Node.
auto reductionMax_op =
cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(prevBlockOutputTensor)
.setyDesc(afterMaxReductionTensor)
.setreductionDesc(reductionMaxDesc)
.build();
// Define the subtract descriptor
auto subtractDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB);
// Create a subtract Node.
auto subtract_op = binary_pw_op_create(prevBlockOutputTensor, afterMaxReductionTensor,
afterSubtractionTensor, subtractDesc);
// Define the exponent descriptor
auto exponentDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP);
// Create a exponent Node.
auto exponent_op =
unary_pw_op_create(afterSubtractionTensor, afterExponentTensor, exponentDesc);
// Define the reduction descriptor
auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
.build();
// Create a reduction add Node.
auto reductionAdd_op =
cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(afterExponentTensor)
.setyDesc(afterAddReductionTensor)
.setreductionDesc(reductionAddDesc)
.build();
// Define the division descriptor
auto divisionDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_DIV);
  // Create a division Node.
auto division_op = binary_pw_op_create(afterExponentTensor, afterAddReductionTensor,
afterDivisionTensor, divisionDesc);
ops.push_back(std::move(reductionMax_op));
ops.push_back(std::move(subtract_op));
ops.push_back(std::move(exponent_op));
ops.push_back(std::move(reductionAdd_op));
ops.push_back(std::move(division_op));
return afterDivisionTensor;
}

static cudnn_frontend::Tensor createDropout(int64_t b, int64_t h, int64_t s_q, int64_t s_kv,
int64_t d, int64_t seed, double probability,
cudnnDataType_t tensorType,
// NOLINTNEXTLINE(runtime/references)
std::vector<cudnn_frontend::Operation> &ops,
cudnn_frontend::Tensor const &prevBlockOutputTensor) {
NVTE_CHECK(ops.size() != 0, "Dropout DAG constructed incorrectly as the first one");
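  // Inverted dropout: an RNG node draws a Bernoulli(1 - p) keep mask, the mask is
  // multiplied into the softmax output, and the result is rescaled by the by-value
  // constant in D_CONST_ID (expected to be 1 / (1 - p), supplied by the caller).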
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1};
// mask for the dropout
auto dropoutMaskTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 200, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
auto reorder_type =
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16;
// after dropout tensor
auto afterDropoutTensor =
cudnn_frontend::TensorBuilder()
.setDim(4, afterBMM1_dim)
.setStride(4, afterBMM1_stride)
.setId(S_ID)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(tensorType)
.setVirtual(false)
.setByValue(false)
.setReorderType(reorder_type)
.build();
// scale after dropout
auto scaleDropoutTensor = tensor_create(tensorType, D_CONST_ID, scale_dim, scale_stride, false,
true); // is by value
// after Scale
auto afterScaleTensor = tensor_create(tensorType, VIRTUAL_ID + 201, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
// Define the RNG descriptor for sampling the dropout mask
auto rngDesc = cudnn_frontend::RngDescBuilder()
.setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI)
.setBernoulliDistProbability(1.0 - probability)
.build();
// Create an RNG Node.
auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR)
.setyDesc(dropoutMaskTensor)
.setSeed(seed)
.setRngDesc(rngDesc)
.build();
// Define the multiply mask descriptor
auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply mask Node.
auto maskMul_op = binary_pw_op_create(prevBlockOutputTensor, dropoutMaskTensor,
afterDropoutTensor, maskMulDesc);
// Define the multiply scale descriptor
auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply scale Node.
auto scaleMul_op =
binary_pw_op_create(afterDropoutTensor, scaleDropoutTensor, afterScaleTensor, scaleMulDesc);
ops.push_back(std::move(rng_op));
ops.push_back(std::move(maskMul_op));
ops.push_back(std::move(scaleMul_op));
return afterScaleTensor;
}
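// Note: createDropout implements inverted dropout. The RNG node samples a
// Bernoulli keep-mask with P(keep) = 1 - probability, and the kept activations
// are scaled by 1 / (1 - probability) (bound to D_CONST_ID at execution time),
// so the expectation of every element is unchanged:
//   E[out] = (1 - p) * x / (1 - p) + p * 0 = x.
// For example, p = 0.1 keeps 90% of the entries and scales them by ~1.111.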
static void createBMM2(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
// NOLINTNEXTLINE(runtime/references)
std::vector<cudnn_frontend::Operation> &ops,
cudnn_frontend::Tensor const &prevBlockOutputTensor) {
NVTE_CHECK(ops.size() != 0, "BMM2 op constructed incorrectly as the first one");
int64_t seqlen_dim[4] = {b, 1, 1, 1};
int64_t seqlen_stride[4] = {1, 1, 1, 1};
int64_t v_dim[4] = {b, h, s_kv, d};
int64_t v_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, v_stride, layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
int64_t o_dim[4] = {b, h, s_q, d};
int64_t o_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
auto seqlenQTensor =
tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
auto seqlenKTensor =
tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
auto vTensor = tensor_create(tensorType, V_ID, v_dim, v_stride, false, false);
// second GEMM output
auto oTensor = tensor_create(tensorType, O_ID, o_dim, o_stride, false, false);
// Define the matmul 2 desc
// set padding value optionally to 0 for writing zeros to O tensor (if not set, old behaviour)
auto matmul_2_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setPaddingValue(0.0f)
.build();
// Create a matmul 2 Node
auto matmul_op2 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(prevBlockOutputTensor)
.setbMatDesc(vTensor)
.setcMatDesc(oTensor)
.setmOverrideDesc(seqlenQTensor)
.setkOverrideDesc(seqlenKTensor)
.setmatmulDesc(matmul_2_Desc)
.build();
ops.push_back(std::move(matmul_op2));
}
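// A note on the override descriptors above: cuDNN reads the Q_SEQLEN_ID /
// K_SEQLEN_ID tensors as per-batch GEMM dimension overrides, so each sequence
// is multiplied only over its actual (unpadded) length, while the padding
// value of 0 makes the padded rows of O well-defined zeros.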
static cudnn_frontend::Tensor createSoftmaxBackward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv,
int64_t d, NVTE_QKV_Layout layout,
cudnnDataType_t tensorType,
// NOLINTNEXTLINE(runtime/references)
std::vector<cudnn_frontend::Operation> &ops,
cudnn_frontend::Tensor const &yTensor,
cudnn_frontend::Tensor const &dyTensor) {
NVTE_CHECK(ops.size() != 0, "Softmax backward constructed incorrectly as the first one");
int64_t p_dim[4] = {b, h, s_q, s_kv};
int64_t p_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, p_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix);
int64_t p_reduction_dim[4] = {b, h, s_q, 1};
int64_t p_reduction_stride[4];
p_reduction_stride[3] = 1;
p_reduction_stride[2] = 1;
p_reduction_stride[1] = s_q;
p_reduction_stride[0] = s_q * h;
int64_t const_dim[4] = {1, 1, 1, 1};
int64_t const_stride[4] = {1, 1, 1, 1};
// creating all tensors
auto softmaxScaleTensor =
tensor_create(CUDNN_DATA_FLOAT, S_CONST_ID, const_dim, const_stride, false, true);
auto dyMulYTensor =
tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 250, p_dim, p_stride, true, false);
auto dxAfterReductionTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 251, p_reduction_dim,
p_reduction_stride, true, false);
auto dxAfterSubtractionTensor =
tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 252, p_dim, p_stride, true, false);
auto dxUnscaleTensor =
tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 253, p_dim, p_stride, true, false);
auto dxTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 254, p_dim, p_stride, true, false);
// creating all ops
// mul (y * dy)
auto mul_1_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
auto mul_1_op = binary_pw_op_create(yTensor, dyTensor, dyMulYTensor, mul_1_desc);
// reduction add sum (y * dy)
auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
.build();
auto reductionAdd_op =
cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(dyMulYTensor)
.setyDesc(dxAfterReductionTensor)
.setreductionDesc(reductionAddDesc)
.build();
// subtraction (dy - sum(y * dy))
auto sub_0_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB);
auto sub_0_op =
binary_pw_op_create(dyTensor, dxAfterReductionTensor, dxAfterSubtractionTensor, sub_0_desc);
// mul (y * (dy - sum(y * dy)))
auto mul_2_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
auto mul_2_op =
binary_pw_op_create(yTensor, dxAfterSubtractionTensor, dxUnscaleTensor, mul_2_desc);
// mul (scale * dx)
auto mul_3_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
auto mul_3_op = binary_pw_op_create(dxUnscaleTensor, softmaxScaleTensor, dxTensor, mul_3_desc);
ops.push_back(std::move(mul_1_op));
ops.push_back(std::move(reductionAdd_op));
ops.push_back(std::move(sub_0_op));
ops.push_back(std::move(mul_2_op));
ops.push_back(std::move(mul_3_op));
return dxTensor;
}
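// The node chain in createSoftmaxBackward follows from the softmax Jacobian:
// with y = softmax(scale * x), the input gradient is
//   dx_i = scale * y_i * (dy_i - sum_j(y_j * dy_j)),
// which maps one-to-one onto mul_1 (y * dy), the add-reduction over s_kv,
// sub_0 (dy - sum), mul_2 (multiply by y), and mul_3 (multiply by the scale).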
void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
bool is_training, float scaling_factor, float dropout_probability,
NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, void *devPtrQ, void *devPtrK,
void *devPtrV, void *devPtrS, void *devPtrO, void *devPtrBias,
void *devCuSeqlenQ, void *devCuSeqlenK, void *workspace,
size_t *workspace_size, cudnnDataType_t tensorType,
cudaStream_t stream, cudnnHandle_t handle) {
try {
constexpr int64_t seed = 0; // TODO(rewang): replace this with device seed/offset
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
FADescriptor descriptor{b, h,
s_q, s_kv,
d, scaling_factor,
is_training, dropout_probability,
layout, bias_type,
mask_type, tensorType};
using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
static CacheType fmha_fprop_cache;
bool enable_dropout = (dropout_probability != 0.0f);
NVTE_CHECK(!enable_dropout,
"dropout probability > 0 in fused_attn_max_512 has not been implemented.");
// Fetch the plan from the cache if present; otherwise build it and update the cache
auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) {
// if hit, return
auto it = cache.find(descriptor);
if (it != cache.end()) {
auto plan = it->second;
return plan;
}
// otherwise, build the op_graph and the plan. Then update cache
std::vector<cudnn_frontend::Operation const *> all_ops;
std::vector<cudnn_frontend::Operation> ops;
createScale(b, h, s_q, s_kv, d, layout, tensorType, ops);
// if bias is used, the S buffer must be zeroed so that dbias is computed correctly
auto zero_s = (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS);
auto bmm1_output = createBMM1(b, h, s_q, s_kv, d, layout, tensorType, zero_s, ops);
NVTE_CHECK(bias_type != NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS,
"NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS has not been implemented.");
if (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) {
createBias(b, h, s_q, s_kv, d, layout, tensorType, ops, bmm1_output);
}
auto mask_output = createMask(b, h, s_q, s_kv, d, layout, mask_type, tensorType, ops,
bmm1_output, false);
NVTE_CHECK(dropout_probability != 1.0f, "Dropout probability cannot be 1.0.");
// TODO(rewang): check whether devPtrS can be removed
bool softmax_output_virtual = enable_dropout; // || devPtrS == nullptr;
auto softmax_output =
createSoftmaxForward(b, h, s_q, s_kv, d, layout, enable_dropout,
softmax_output_virtual, tensorType, ops, mask_output);
if (dropout_probability != 0.0f) {
auto dropout_output = createDropout(b, h, s_q, s_kv, d, seed, dropout_probability,
tensorType, ops, softmax_output);
createBMM2(b, h, s_q, s_kv, d, layout, tensorType, ops, dropout_output);
} else {
createBMM2(b, h, s_q, s_kv, d, layout, tensorType, ops, softmax_output);
}
for (unsigned int i = 0; i < ops.size(); i++) {
all_ops.push_back(&ops[i]);
}
// Create an Operation Graph
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle)
.setOperationGraph(all_ops.size(), all_ops.data())
.build();
cudnn_frontend::EngineConfigList filtered_configs;
auto statuses = cudnn_frontend::get_heuristics_list<1>(
{"heuristics_instant"}, opGraph, allowAllConfig, filtered_configs, true);
if (filtered_configs.size() == 0) {
cudnn_frontend::set_error_and_throw_exception(
nullptr, CUDNN_STATUS_NOT_SUPPORTED,
"run_mha_fprop: No config returned by the heuristics");
}
auto plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(filtered_configs[0], opGraph.getTag())
.build();
cache.insert({descriptor, plan});
return plan;
};
auto plan = get_plan(fmha_fprop_cache, descriptor);
auto plan_workspace_size = plan.getWorkspaceSize();
// Exit to request upper level API to allocate memory if needed
if (workspace == nullptr) {
size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t);
*workspace_size = plan_workspace_size + actual_seqlen_workspace_size;
return;
}
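// A sketch of the expected two-phase calling convention (hypothetical caller
// code, not part of this file):
//   size_t ws_size = 0;
//   fused_attn_max_512_fwd_impl(..., /*workspace=*/nullptr, &ws_size, ...);  // query only
//   void *ws = nullptr;
//   NVTE_CHECK_CUDA(cudaMalloc(&ws, ws_size));
//   fused_attn_max_512_fwd_impl(..., ws, &ws_size, ...);  // executes the plan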
// Prepare actual seqlen
constexpr size_t nthreads_per_block = 128;
const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
void *devActualSeqlenK = static_cast<int8_t *>(devActualSeqlenQ) + b * sizeof(int32_t);
cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
b, static_cast<const int32_t *>(devCuSeqlenQ),
static_cast<const int32_t *>(devCuSeqlenK), static_cast<int32_t *>(devActualSeqlenQ),
static_cast<int32_t *>(devActualSeqlenK));
// large negative value standing in for -infinity in the padding mask; swap in
// the true lowest float value if one becomes available here
float negInfinity = -1.0E+10;
float scale_dropout = 1 / (1 - dropout_probability);
std::set<std::pair<uint64_t, void *>> data_ptrs;
// add all the data pointers to be used in the variant pack
data_ptrs.insert(std::pair<uint64_t, void *>(Q_ID, devPtrQ));
data_ptrs.insert(std::pair<uint64_t, void *>(K_ID, devPtrK));
data_ptrs.insert(std::pair<uint64_t, void *>(V_ID, devPtrV));
data_ptrs.insert(std::pair<uint64_t, void *>(Q_SEQLEN_ID, devActualSeqlenQ));
data_ptrs.insert(std::pair<uint64_t, void *>(K_SEQLEN_ID, devActualSeqlenK));
data_ptrs.insert(std::pair<uint64_t, void *>(MASK_VAL_ID, &negInfinity));
if (tensorType == CUDNN_DATA_FLOAT) {
data_ptrs.insert(std::pair<uint64_t, void *>(S_CONST_ID, &scaling_factor));
} else if (tensorType == CUDNN_DATA_HALF) {
__half cast_scaling_factor{scaling_factor};
data_ptrs.insert(std::pair<uint64_t, void *>(S_CONST_ID, &cast_scaling_factor));
} else if (tensorType == CUDNN_DATA_BFLOAT16) {
__nv_bfloat16 cast_scaling_factor{scaling_factor};
data_ptrs.insert(std::pair<uint64_t, void *>(S_CONST_ID, &cast_scaling_factor));
} else {
std::cerr << "Not supported tensorType." << std::endl;
}
data_ptrs.insert(std::pair<uint64_t, void *>(O_ID, devPtrO));
if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) {
data_ptrs.insert(std::pair<uint64_t, void *>(B_ID, devPtrBias));
}
if (devPtrS != nullptr) {
data_ptrs.insert(std::pair<uint64_t, void *>(S_ID, devPtrS));
}
if (enable_dropout) {
data_ptrs.insert(std::pair<uint64_t, void *>(D_CONST_ID, &scale_dropout));
}
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace)
.setDataPointers(data_ptrs)
.build();
NVTE_CHECK_CUDNN(
cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc()));
} catch (cudnn_frontend::cudnnException &e) {
NVTE_ERROR(e.what());
}
}
void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
float scaling_factor, float dropout_probability,
NVTE_QKV_Layout layout, NVTE_Mask_Type mask_type,
NVTE_Bias_Type bias_type, void *devPtrQ, void *devPtrK,
void *devPtrV, void *devPtrS, void *devPtrdQ, void *devPtrdK,
void *devPtrdV, void *devPtrdO, void *devPtrdS, void *devPtrdBias,
void *devCuSeqlenQ, void *devCuSeqlenK, void *workspace,
size_t *workspace_size, cudnnDataType_t tensorType,
cudaStream_t stream, cudnnHandle_t handle) {
try {
// Bind the cuDNN handle to the CUDA stream
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
FADescriptor descriptor{
b, h, s_q, s_kv, d, scaling_factor, true, dropout_probability,
layout, bias_type, mask_type, tensorType};
using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
static CacheType fmha_bprop_cache;
auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) {
auto it = cache.find(descriptor);
if (it != cache.end()) {
return it->second;
}
std::vector<cudnn_frontend::Operation const *> all_ops;
std::vector<cudnn_frontend::Operation> ops;
// Creates the necessary tensor descriptors
int64_t q_dim[4] = {b, h, s_q, d};
int64_t q_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout,
NVTE_QKV_Matrix::NVTE_Q_Matrix);
int64_t k_dim[4] = {b, h, s_kv, d};
int64_t k_stride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, k_stride, layout,
NVTE_QKV_Matrix::NVTE_K_Matrix); // type is correct as K is not transposed
int64_t v_dim[4] = {b, h, d, s_kv};
int64_t v_stride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, v_stride, layout,
NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose); // type is correct as V is transposed
int64_t p_dim[4] = {b, h, s_q, s_kv};
int64_t p_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, p_stride, layout,
NVTE_QKV_Matrix::NVTE_S_Matrix);
int64_t p_transpose_dim[4] = {b, h, s_kv, s_q};
int64_t p_transpose_stride[4];
p_transpose_stride[0] = p_stride[0];
p_transpose_stride[1] = p_stride[1];
p_transpose_stride[2] = p_stride[3];
p_transpose_stride[3] = p_stride[2];
int64_t o_dim[4] = {b, h, s_q, d};
int64_t o_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout,
NVTE_QKV_Matrix::NVTE_O_Matrix);
int64_t seqlen_dim[4] = {b, 1, 1, 1};
int64_t seqlen_stride[4] = {1, 1, 1, 1};
int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1};
// inputs to fprop
auto qTensor = tensor_create(tensorType, Q_ID, q_dim, q_stride, false, false);
auto kTensor = tensor_create(tensorType, K_ID, k_dim, k_stride, false, false);
auto vTensor = tensor_create(tensorType, V_ID, v_dim, v_stride, false, false);
auto seqlenQTensor = tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim,
seqlen_stride, false, false);
auto seqlenKTensor = tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim,
seqlen_stride, false, false);
// gradient of the output
auto doTensor = tensor_create(tensorType, dO_ID, o_dim, o_stride, false, false);
auto reorder_type =
cudnn_frontend::cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16;
// activation from fprop
auto pTensor =
cudnn_frontend::TensorBuilder()
.setDim(4, p_dim)
.setStride(4, p_stride)
.setId(S_ID)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(tensorType)
.setVirtual(false)
.setByValue(false)
.setReorderType(reorder_type)
.build();
// outputs from bprop
auto dqTensor = tensor_create(tensorType, dQ_ID, q_dim, q_stride, false, false);
auto dkTensor = tensor_create(tensorType, dK_ID, k_dim, k_stride, false, false);
auto dvTensor = tensor_create(tensorType, dV_ID, k_dim, k_stride, false,
false); // not transposed therefore k_dim and k_stride
////////////////////////////////////////////////////////
// start creating the ops and the intermediate tensors
auto pReshapeTensor = tensor_create(tensorType, VIRTUAL_ID + 300, p_transpose_dim,
p_transpose_stride, true, false);
// reshape to perform transpose and make pReshape
auto reshape_op =
cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR)
.setxDesc(pTensor)
.setyDesc(pReshapeTensor)
.build();
ops.push_back(std::move(reshape_op));
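// Note: p_transpose_stride is p_stride with the last two entries swapped, so
// this RESHAPE node realizes the transpose of S ([b, h, s_q, s_kv] ->
// [b, h, s_kv, s_q]) purely by reinterpreting strides; no data is moved.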
// scale dropout
auto dropoutScaleTensor = tensor_create(CUDNN_DATA_FLOAT, D_CONST_ID, scale_dim,
scale_stride, false, true); // is by value
auto pAfterScaleTensor = tensor_create(tensorType, VIRTUAL_ID + 301, p_transpose_dim,
p_transpose_stride, true, false);
auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
auto scaleMul_op = binary_pw_op_create(pReshapeTensor, dropoutScaleTensor,
pAfterScaleTensor, scaleMulDesc);
ops.push_back(std::move(scaleMul_op));
// perform absolute operation to remove the mask bit
auto pTransposeAfterAbsTensor = tensor_create(
tensorType, VIRTUAL_ID + 302, p_transpose_dim, p_transpose_stride, true, false);
auto absDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_ABS);
auto abs_op = unary_pw_op_create(pAfterScaleTensor, pTransposeAfterAbsTensor, absDesc);
ops.push_back(std::move(abs_op));
// matmul to calculate dvTensor
// set padding value optionally to 0 for writing zeros to dV tensor (if not set, old
// behaviour)
auto matmul_0_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setPaddingValue(0.0f)
.build();
auto matmul_op0 =
cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(pTransposeAfterAbsTensor)
.setbMatDesc(doTensor)
.setcMatDesc(dvTensor)
.setmOverrideDesc(seqlenKTensor)
.setkOverrideDesc(seqlenQTensor)
.setmatmulDesc(matmul_0_Desc)
.build();
ops.push_back(std::move(matmul_op0));
// matmul to calculate dpTensor
auto dpTensor =
tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 303, p_dim, p_stride, true, false);
auto matmul_1_Desc =
cudnn_frontend::MatMulDescBuilder().setComputeType(CUDNN_DATA_FLOAT).build();
auto matmul_op1 =
cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(doTensor)
.setbMatDesc(vTensor)
.setcMatDesc(dpTensor)
.setmOverrideDesc(seqlenQTensor)
.setnOverrideDesc(seqlenKTensor)
.setmatmulDesc(matmul_1_Desc)
.build();
ops.push_back(std::move(matmul_op1));
// mask the values which were dropped in dropout
auto pAbsTensor =
tensor_create(tensorType, VIRTUAL_ID + 304, p_dim, p_stride, true, false);
auto p_absDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_ABS);
auto p_abs_op = unary_pw_op_create(pTensor, pAbsTensor, p_absDesc);
ops.push_back(std::move(p_abs_op));
// create the dropout mask
auto zeroTensor = tensor_create(CUDNN_DATA_FLOAT, MASK_VAL_ID, scale_dim, scale_stride,
false, true); // is by value
auto dropoutMaskTensor =
tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 305, p_dim, p_stride, true, false);
auto greater_than_0_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_GT);
auto greater_than_0_op =
binary_pw_op_create(pTensor, zeroTensor, dropoutMaskTensor, greater_than_0_desc);
ops.push_back(std::move(greater_than_0_op));
// scale for the dropout
auto dpAfterScaleTensor =
tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 306, p_dim, p_stride, true, false);
auto mul_0_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
auto mul_0_op =
binary_pw_op_create(dpTensor, dropoutScaleTensor, dpAfterScaleTensor, mul_0_desc);
ops.push_back(std::move(mul_0_op));
// drop the values based on the dropout mask
auto dpAfterDropoutTensor =
tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 307, p_dim, p_stride, true, false);
auto selection_0_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_BINARY_SELECT);
auto selection_0_op =
ternary_pw_op_create(dpAfterScaleTensor, zeroTensor, dropoutMaskTensor,
dpAfterDropoutTensor, selection_0_desc);
ops.push_back(std::move(selection_0_op));
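// Recap of the recompute trick above (inherited from the cuDNN fMHA sample):
// the forward pass encodes the dropout decision in the sign of the stored S,
// so abs(S) restores the softmax magnitudes while the comparison (S > 0)
// rebuilds the keep-mask, and the binary-select re-applies that mask together
// with the 1/(1-p) scale to dP without any separately saved state.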
// softmax backward
auto dsTensor = createSoftmaxBackward(b, h, s_q, s_kv, d, layout, tensorType, ops,
pAbsTensor, dpAfterDropoutTensor);
// mask
auto dsAfterMaskTensor =
createMask(b, h, s_q, s_kv, d, layout, mask_type, tensorType, ops, dsTensor, true);
// dbias tensor
int64_t dbias_dim[4] = {1, h, s_q, s_kv};
int64_t dbias_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
auto dBiasTensor =
tensor_create(tensorType, dBias_ID, dbias_dim, dbias_stride, false, false);
if (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) {
auto softmaxScaleTensor = tensor_create(CUDNN_DATA_FLOAT, S_CONST_ID, scale_dim,
scale_stride, false, true);
auto softmaxScaleReciprocalTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 401, scale_dim, scale_stride, true, false);
auto dbiasBeforeScaleTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 402,
dbias_dim, dbias_stride, true, false);
// Define the reduction descriptor
auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
.build();
// Create a reduction add node to compute the dbias
auto reductionAdd_op =
cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(dsAfterMaskTensor)
.setyDesc(dbiasBeforeScaleTensor)
.setreductionDesc(reductionAddDesc)
.build();
ops.push_back(std::move(reductionAdd_op));
// take the reciprocal of the scale
auto reciprocal_scale_desc =
pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_RECIPROCAL);
auto reciprocal_scale_op = unary_pw_op_create(
softmaxScaleTensor, softmaxScaleReciprocalTensor, reciprocal_scale_desc);
ops.push_back(std::move(reciprocal_scale_op));
// apply the scale
auto dBias_scale_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
auto dBias_scale_op =
binary_pw_op_create(dbiasBeforeScaleTensor, softmaxScaleReciprocalTensor,
dBiasTensor, dBias_scale_desc);
ops.push_back(std::move(dBias_scale_op));
}
// matmul to calculate dqTensor
// set padding value optionally to 0 for writing zeros to dqTensor (if not set, old
// behaviour)
auto matmul_2_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setPaddingValue(0.0f)
.build();
auto matmul_op2 =
cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(dsAfterMaskTensor)
.setbMatDesc(kTensor)
.setcMatDesc(dqTensor)
.setmOverrideDesc(seqlenQTensor)
.setkOverrideDesc(seqlenKTensor)
.setmatmulDesc(matmul_2_Desc)
.build();
ops.push_back(std::move(matmul_op2));
// reshape for transpose of ds
auto dsAfterMaskReshapeTensor = tensor_create(
tensorType, VIRTUAL_ID + 308, p_transpose_dim, p_transpose_stride, true, false);
auto reshape_2_op =
cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR)
.setxDesc(dsAfterMaskTensor)
.setyDesc(dsAfterMaskReshapeTensor)
.build();
ops.push_back(std::move(reshape_2_op));
// matmul to calculate dkTensor
// set padding value optionally to 0 for writing zeros to dktensor (if not set, old
// behaviour)
auto matmul_3_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setPaddingValue(0.0f)
.build();
auto matmul_op3 =
cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(dsAfterMaskReshapeTensor)
.setbMatDesc(qTensor)
.setcMatDesc(dkTensor)
.setmOverrideDesc(seqlenKTensor)
.setkOverrideDesc(seqlenQTensor)
.setmatmulDesc(matmul_3_Desc)
.build();
ops.push_back(std::move(matmul_op3));
/////////////////////////////////////////////////////////////////
for (unsigned int i = 0; i < ops.size(); i++) {
all_ops.push_back(&ops[i]);
}
// Create an Operation Graph
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle)
.setOperationGraph(all_ops.size(), all_ops.data())
.build();
cudnn_frontend::EngineConfigList filtered_configs;
auto statuses = cudnn_frontend::get_heuristics_list<1>(
{"heuristics_instant"}, opGraph, allowAllConfig, filtered_configs, true);
if (filtered_configs.size() == 0) {
cudnn_frontend::set_error_and_throw_exception(
nullptr, CUDNN_STATUS_NOT_SUPPORTED,
"run_mha_bprop: No config returned by the heuristics");
}
auto plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(filtered_configs[0], opGraph.getTag())
.build();
cache.insert({descriptor, plan});
return plan;
};
auto plan = get_plan(fmha_bprop_cache, descriptor);
auto plan_workspace_size = plan.getWorkspaceSize();
// Exit to request upper level API to allocate memory if needed
if (workspace == nullptr) {
size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t);
*workspace_size = plan_workspace_size + actual_seqlen_workspace_size;
return;
}
constexpr size_t nthreads_per_block = 128;
const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
void *devActualSeqlenK = static_cast<int8_t *>(devActualSeqlenQ) + b * sizeof(int32_t);
cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
b, static_cast<const int32_t *>(devCuSeqlenQ),
static_cast<const int32_t *>(devCuSeqlenK), static_cast<int32_t *>(devActualSeqlenQ),
static_cast<int32_t *>(devActualSeqlenK));
std::set<std::pair<uint64_t, void *>> data_ptrs;
// add all the data pointers to be used in the variant pack
data_ptrs.insert(std::pair<uint64_t, void *>(dQ_ID, devPtrdQ));
data_ptrs.insert(std::pair<uint64_t, void *>(dK_ID, devPtrdK));
data_ptrs.insert(std::pair<uint64_t, void *>(dV_ID, devPtrdV));
data_ptrs.insert(std::pair<uint64_t, void *>(Q_ID, devPtrQ));
data_ptrs.insert(std::pair<uint64_t, void *>(K_ID, devPtrK));
data_ptrs.insert(std::pair<uint64_t, void *>(V_ID, devPtrV));
data_ptrs.insert(std::pair<uint64_t, void *>(S_ID, devPtrS));
data_ptrs.insert(std::pair<uint64_t, void *>(dO_ID, devPtrdO));
data_ptrs.insert(std::pair<uint64_t, void *>(dS_ID, devPtrdS));
data_ptrs.insert(std::pair<uint64_t, void *>(Q_SEQLEN_ID, devActualSeqlenQ));
data_ptrs.insert(std::pair<uint64_t, void *>(K_SEQLEN_ID, devActualSeqlenK));
if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) {
data_ptrs.insert(std::pair<uint64_t, void *>(dBias_ID, devPtrdBias));
}
NVTE_CHECK(dropout_probability == 0.f,
"dropout probability > 0 in fused_attn_max_512 has not been implemented.");
float zeroVal = 0.0f;
float dropoutScale = 1.0f / (1.0f - dropout_probability);
data_ptrs.insert(std::pair<uint64_t, void *>(D_CONST_ID, &dropoutScale));
data_ptrs.insert(std::pair<uint64_t, void *>(S_CONST_ID, &scaling_factor));
data_ptrs.insert(std::pair<uint64_t, void *>(MASK_VAL_ID, &zeroVal));
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace)
.setDataPointers(data_ptrs)
.build();
NVTE_CHECK_CUDNN(
cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc()));
} catch (cudnn_frontend::cudnnException &e) {
NVTE_ERROR(e.what());
}
}
} // namespace fused_attn
using namespace transformer_engine::fused_attn;
void fused_attn_max_512_fwd_qkvpacked(
size_t batch, size_t max_seqlen, size_t num_head, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_Output_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
// Only the is_training=true path has been verified
NVTE_CHECK(is_training, "is_training=False is not implemented in fused_attn_max_512.");
NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED,
"qkv_layout must be NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED.");
// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;
// Q, K, V are interleaved along dim 2 of [b, s, 3, h, d]; the supported
// FP16/BF16 element types are 2 bytes wide, so the byte offset between the
// Q, K, and V slices is 2 * h * d.
const auto stride = 2 * num_head * head_dim;
void *devPtrQ = static_cast<void *>(devPtrQKV);
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
void *devPtrBias = static_cast<void *>(input_Bias->data.dptr);
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
if (Aux_Output_Tensors->size == 0) {
Aux_Output_Tensors->size = 1;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_Output_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, max_seqlen, max_seqlen};
output_S->data.dtype = input_QKV->data.dtype;
} else if (Aux_Output_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_Output_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
}
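// Shape-inference convention: on the first call the caller passes an empty
// NVTETensorPack, and this function only records the expected shape/dtype of
// the softmax auxiliary tensor (dptr stays null); on the second call the pack
// carries a real allocation and devPtrS is wired into the variant pack.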
void *devCuSeqlen = cu_seqlens->data.dptr;
// TODO(rewang): dropout seed
// void* devPtrDropoutSeed = reinterpret_cast<void *>(
// reinterpret_cast<uint64_t*>(rng_state->data.dptr));
// void* devPtrDropoutOffset = reinterpret_cast<void *>(
// reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
const DType QKV_type = input_QKV->data.dtype;
size_t workspace_size = 0;
// TODO(rewang): replace CPU seed
fused_attn_max_512_fwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ,
devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, devCuSeqlen,
devCuSeqlen, workspace->data.dptr, &workspace_size,
get_cudnn_dtype(QKV_type), stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
}
}
void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_head, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_Output_Tensors, const Tensor *q_cu_seqlens,
const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
// Only is_training is verified
NVTE_CHECK(is_training, "is_training=False is not implemented in fused_attn_max_512.");
NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED,
"qkv_layout must be NVTE_QKV_Layout::NVTE_KV_INTERLEAVED.");
NVTE_CHECK(bias_type == NVTE_Bias_Type::NVTE_NO_BIAS ||
bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS,
"NVTE_PRE_SCALE_BIAS is not implemented in fused_attn_max_512.");
// Q shape is [b, s, h, d]
void *devPtrQ = input_Q->data.dptr;
// KV shape is [b, s, 2, h, d]
// 2-byte FP16/BF16 elements: byte offset between the K and V slices
const auto stride = 2 * num_head * head_dim;
void *devPtrK = input_KV->data.dptr;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrK) + stride);
void *devPtrBias = input_Bias->data.dptr;
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
const DType q_type = input_Q->data.dtype;
const DType kv_type = input_KV->data.dtype;
NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV.");
if (Aux_Output_Tensors->size == 0) {
Aux_Output_Tensors->size = 1;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_Output_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen};
output_S->data.dtype = q_type;
} else if (Aux_Output_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_Output_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
}
void *devQCuSeqlen = q_cu_seqlens->data.dptr;
void *devKVCuSeqlen = kv_cu_seqlens->data.dptr;
// TODO(rewang): dropout seed
// void* devPtrDropoutSeed = reinterpret_cast<void *>(
// reinterpret_cast<uint64_t*>(rng_state->data.dptr));
// void* devPtrDropoutOffset = reinterpret_cast<void *>(
// reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
// TODO(rewang): replace CPU seed
fused_attn_max_512_fwd_impl(batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ,
devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, devQCuSeqlen,
devKVCuSeqlen, workspace->data.dptr, &workspace_size,
get_cudnn_dtype(q_type), stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
}
}
void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV,
const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors,
Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED,
"qkv_layout must be NVTE_QKV_INTERLEAVED.");
// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;
auto stride = 2 * num_head * head_dim; // byte offset, 2-byte FP16/BF16 elements
void *devPtrQ = devPtrQKV;
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
void *devPtrdO = input_dO->data.dptr;
// dQKV shape is [b, s, 3, h, d]
void *devPtrdQKV = output_dQKV->data.dptr;
void *devPtrdQ = devPtrdQKV;
void *devPtrdK = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + stride);
void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + 2 * stride);
void *devPtrdBias = output_dBias->data.dptr;
NVTE_CHECK(Aux_CTX_Tensors->size == 1);
void *devPtrS = nullptr;
if (Aux_CTX_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
}
// devPtrdS reuses the memory of devPtrS
void *devPtrdS = devPtrS;
void *devPtrCuSeqlens = cu_seqlens->data.dptr;
const auto qkv_type = input_QKV->data.dtype;
size_t workspace_size = 0;
fused_attn_max_512_bwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim, attn_scale,
p_dropout, qkv_layout, mask_type, bias_type, devPtrQ, devPtrK,
devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdS,
devPtrdBias, devPtrCuSeqlens, devPtrCuSeqlens, workspace->data.dptr,
&workspace_size, get_cudnn_dtype(qkv_type), stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
}
}
void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_head, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors,
Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED,
"qkv_layout must be NVTE_KV_INTERLEAVED.");
// Q shape is [b, s, h, d]
// KV shape is [b, s, 2, h, d]
auto stride = 2 * num_head * head_dim; // byte offset, 2-byte FP16/BF16 elements
void *devPtrQ = input_Q->data.dptr;
void *devPtrK = input_KV->data.dptr;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrK) + stride);
void *devPtrdO = input_dO->data.dptr;
// dQ shape is [b, s, h, d]
// dKV shape is [b, s, 2, h, d]
void *devPtrdQ = output_dQ->data.dptr;
void *devPtrdK = output_dKV->data.dptr;
void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdK) + stride);
void *devPtrdBias = output_dBias->data.dptr;
NVTE_CHECK(Aux_CTX_Tensors->size == 1);
void *devPtrS = nullptr;
if (Aux_CTX_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
}
// devPtrdS reuses the memory of devPtrS
void *devPtrdS = devPtrS;
void *devPtrQCuSeqlens = q_cu_seqlens->data.dptr;
void *devPtrKVCuSeqlens = kv_cu_seqlens->data.dptr;
const auto q_type = input_Q->data.dtype;
const auto kv_type = input_KV->data.dtype;
NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV.");
size_t workspace_size = 0;
fused_attn_max_512_bwd_impl(
batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout,
mask_type, bias_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV,
devPtrdO, devPtrdS, devPtrdBias, devPtrQCuSeqlens, devPtrKVCuSeqlens, workspace->data.dptr,
&workspace_size, get_cudnn_dtype(q_type), stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
}
}
} // namespace transformer_engine
#endif // CUDNN_VERSION >= 8901
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file fused_attn_max_512.h
* \brief Functions for fused attention with seqlen <= 512
*/
#ifndef TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_MAX_512_H_
#define TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_MAX_512_H_
#include "transformer_engine/fused_attn.h"
#include <cudnn.h>
#include "common/common.h"
namespace transformer_engine {
#if (CUDNN_VERSION >= 8901)
void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
size_t head_dim, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_Output_Tensors,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_head, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_Output_Tensors, const Tensor *q_cu_seqlens,
const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV,
const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors,
Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_head, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors,
Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
#endif // CUDNN_VERSION >= 8901
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_MAX_512_H_
@@ -5,6 +5,7 @@
 ************************************************************************/
 #include "transformer_engine/fused_attn.h"
 #include "../common.h"
+#include "utils.h"
 #include "fused_attn_fp8.h"
@@ -70,84 +71,6 @@ std::unordered_map<std::string, int> tensor_name_to_uid = {
     {"VIRTUAL", 80}
 };
 static cudnn_frontend::Tensor createAmax(
     const std::string& amax_tensor_name,
     const cudnn_frontend::Tensor& prevBlockOutputTensor,
@@ -1089,7 +1012,8 @@ void fa_fwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
     FADescriptor descriptor{
         b, h, s_q, s_kv, d,
-        attnScale, isTraining, dropoutProbability, layout, tensorType};
+        attnScale, isTraining, dropoutProbability, layout,
+        NVTE_Bias_Type::NVTE_NO_BIAS, NVTE_Mask_Type::NVTE_PADDING_MASK, tensorType};
     using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
     static CacheType fa_fprop_cache;
@@ -1404,7 +1328,8 @@ void fa_bwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
     FADescriptor descriptor{
         b, h, s_q, s_kv, d,
-        attnScale, false, dropoutProbability, layout, tensorType};
+        attnScale, false, dropoutProbability, layout,
+        NVTE_Bias_Type::NVTE_NO_BIAS, NVTE_Mask_Type::NVTE_PADDING_MASK, tensorType};
     using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
     static CacheType fa_bprop_cache;
...
@@ -51,13 +51,13 @@ void generateMatrixStrides(
         strideA[head_dim_idx] = d;
         strideA[batch_dim_idx] = s_kv * 3 * h * d;
     } else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) {
-        strideA[seqlen_transpose_dim_idx] = 2 * h * d;
-        strideA[hidden_transpose_dim_idx] = 1;
+        strideA[seqlen_dim_idx] = 2 * h * d;
+        strideA[hidden_dim_idx] = 1;
         strideA[head_dim_idx] = d;
         strideA[batch_dim_idx] = s_kv * 2 * h * d;
     } else {
-        strideA[seqlen_transpose_dim_idx] = h * d;
-        strideA[hidden_transpose_dim_idx] = 1;
+        strideA[seqlen_dim_idx] = h * d;
+        strideA[hidden_dim_idx] = 1;
         strideA[head_dim_idx] = d;
         strideA[batch_dim_idx] = s_kv * h * d;
     }
@@ -131,6 +131,99 @@
     }
 }
bool allowAllConfig(cudnnBackendDescriptor_t engine_config) {
(void)engine_config;
return false;
}
cudnn_frontend::Tensor tensor_create(
cudnnDataType_t type, int64_t id,
int64_t const * dim, int64_t const * stride,
bool is_virtual, bool is_value) {
int nbDims = 4;
auto tensor_created = cudnn_frontend::TensorBuilder()
.setDim(nbDims, dim)
.setStride(nbDims, stride)
.setId(id)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(type)
.setVirtual(is_virtual)
.setByValue(is_value)
.build();
return tensor_created;
}
cudnn_frontend::Tensor tensor_create_with_offset(
cudnnDataType_t type, int64_t id,
int64_t const * dim, int64_t const * stride,
bool is_virtual, bool is_value,
std::shared_ptr<cudnn_frontend::Tensor> raggedOffset) {
int nbDims = 4;
auto tensor_created = cudnn_frontend::TensorBuilder()
.setDim(nbDims, dim)
.setStride(nbDims, stride)
.setId(id)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(type)
.setVirtual(is_virtual)
.setByValue(is_value)
.setRaggedOffset(raggedOffset)
.build();
return tensor_created;
}
cudnn_frontend::PointWiseDesc pw_desc_create(
cudnnDataType_t type, cudnnPointwiseMode_t mode) {
auto pw_desc_created = cudnn_frontend::PointWiseDescBuilder()
.setMode(mode)
.setComputeType(type)
.build();
return pw_desc_created;
}
cudnn_frontend::Operation unary_pw_op_create(
cudnn_frontend::Tensor const &xDesc,
cudnn_frontend::Tensor const &yDesc,
cudnn_frontend::PointWiseDesc const &pwDesc) {
auto pw_op_created = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(xDesc)
.setyDesc(yDesc)
.setpwDesc(pwDesc)
.build();
return pw_op_created;
}
cudnn_frontend::Operation binary_pw_op_create(
cudnn_frontend::Tensor const &xDesc,
cudnn_frontend::Tensor const &bDesc,
cudnn_frontend::Tensor const &yDesc,
cudnn_frontend::PointWiseDesc const &pwDesc) {
auto pw_op_created = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(xDesc)
.setbDesc(bDesc)
.setyDesc(yDesc)
.setpwDesc(pwDesc)
.build();
return pw_op_created;
}
cudnn_frontend::Operation ternary_pw_op_create(
cudnn_frontend::Tensor const &xDesc, cudnn_frontend::Tensor const &bDesc,
cudnn_frontend::Tensor const &tDesc, cudnn_frontend::Tensor const &yDesc,
cudnn_frontend::PointWiseDesc const &pwDesc) {
auto pw_op_created = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(xDesc)
.setbDesc(bDesc)
.settDesc(tDesc)
.setyDesc(yDesc)
.setpwDesc(pwDesc)
.build();
return pw_op_created;
}
 // convert cu_seqlens_q to qkv/o_ragged_offset and actual_seqlens_q
 __global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d,
                                       int32_t *cu_seqlens_q, int32_t *actual_seqlens_q,
@@ -144,6 +237,19 @@ __global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d,
         o_ragged_offset[tid] = cu_seqlens_q[tid] * h * d;
     }
 }
// convert cu_seqlens to actual_seqlens
__global__ void cu_seqlens_to_actual_seqlens(size_t b,
int32_t const * const q_cu_seqlens,
int32_t const * const kv_cu_seqlens,
int32_t *q_seqlens, int32_t *kv_seqlens) {
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < b) {
q_seqlens[tid] = q_cu_seqlens[tid + 1] - q_cu_seqlens[tid];
kv_seqlens[tid] = kv_cu_seqlens[tid + 1] - kv_cu_seqlens[tid];
}
}
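// Example: with b = 3 and q_cu_seqlens = {0, 5, 9, 9}, one thread per batch
// entry produces q_seqlens = {5, 4, 0} (and likewise for kv_seqlens). Callers
// launch it with ceil(b / 128.0) blocks of 128 threads, as done in
// fused_attn_max_512_fwd_impl.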
} // namespace fused_attn
// get cuDNN data type
...
@@ -7,9 +7,15 @@
 #ifndef TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_
 #define TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_

+#include "transformer_engine/fused_attn.h"
 #include "transformer_engine/transformer_engine.h"

+#include <cudnn.h>
 #include <cudnn_frontend.h>

+#include <cstdint>
+#include <mutex>

 namespace transformer_engine {
 namespace fused_attn {
@@ -31,6 +37,36 @@ void generateMatrixStrides(
     int64_t d, int64_t* strideA,
     NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix);
bool allowAllConfig(cudnnBackendDescriptor_t engine_config);
cudnn_frontend::Tensor tensor_create(cudnnDataType_t type, int64_t id,
int64_t const *dim,
int64_t const *stride,
bool is_virtual, bool is_value);
cudnn_frontend::Tensor tensor_create_with_offset(
cudnnDataType_t type, int64_t id,
int64_t const * dim, int64_t const * stride,
bool is_virtual, bool is_value,
std::shared_ptr<cudnn_frontend::Tensor> raggedOffset);
cudnn_frontend::PointWiseDesc pw_desc_create(cudnnDataType_t type,
cudnnPointwiseMode_t mode);
cudnn_frontend::Operation unary_pw_op_create(
cudnn_frontend::Tensor const &xDesc, cudnn_frontend::Tensor const &yDesc,
cudnn_frontend::PointWiseDesc const &pwDesc);
cudnn_frontend::Operation binary_pw_op_create(
cudnn_frontend::Tensor const &xDesc, cudnn_frontend::Tensor const &bDesc,
cudnn_frontend::Tensor const &yDesc,
cudnn_frontend::PointWiseDesc const &pwDesc);
cudnn_frontend::Operation ternary_pw_op_create(
cudnn_frontend::Tensor const &xDesc, cudnn_frontend::Tensor const &bDesc,
cudnn_frontend::Tensor const &tDesc, cudnn_frontend::Tensor const &yDesc,
cudnn_frontend::PointWiseDesc const &pwDesc);
 struct FADescriptor {
     std::int64_t b;
     std::int64_t h;
@@ -41,15 +77,19 @@ struct FADescriptor {
     bool isTraining;
     float dropoutProbability;
     NVTE_QKV_Layout layout;
+    NVTE_Bias_Type bias_type;
+    NVTE_Mask_Type mask_type;
     cudnnDataType_t tensor_type;

     bool operator<(const FADescriptor &rhs) const {
         return std::tie(b, h, s_q, s_kv, d,
                         attnScale, isTraining, dropoutProbability,
                         layout, mask_type, bias_type, tensor_type)
             < std::tie(
                 rhs.b, rhs.h, rhs.s_q, rhs.s_kv, rhs.d,
                 rhs.attnScale, rhs.isTraining,
                 rhs.dropoutProbability, rhs.layout,
                 rhs.mask_type, rhs.bias_type, rhs.tensor_type);
     }
 };
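// operator< gives FADescriptor the strict weak ordering required to key the
// std::map plan caches; std::tie compares the members lexicographically, so
// any difference in problem shape, dtype, layout, bias/mask type, or dropout
// maps to a distinct cached cudnn_frontend::ExecutionPlan.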
@@ -57,6 +97,11 @@ __global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d,
                                       int32_t *cu_seqlens_q, int32_t *actual_seqlens_q,
                                       int32_t *qkv_ragged_offset, int32_t *o_ragged_offset);
__global__ void cu_seqlens_to_actual_seqlens(size_t b,
int32_t const * const q_cu_seqlens,
int32_t const * const kv_cu_seqlens,
int32_t *q_seqlens, int32_t *kv_seqlens);
} // namespace fused_attn

cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
...
@@ -78,8 +78,9 @@ enum NVTE_Mask_Type {
 * - O = D * V.T
 *
 * Support Matrix:
 * | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
 * | FP8       | QKV_INTERLEAVED | NO_BIAS                 | PADDING        | Yes     | <= 512          | 64       |
 * | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No      | <= 512          | 64       |
 *
 *
 * \param[in] QKV The QKV tensor in packed format,
@@ -119,8 +120,9 @@ void nvte_fused_attn_fwd_qkvpacked(
/*! \brief Compute the backward of the dot product attention with packed QKV input.
 *
 * Support Matrix:
 * | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
 * | FP8       | QKV_INTERLEAVED | NO_BIAS                 | PADDING        | Yes     | <= 512          | 64       |
 * | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No      | <= 512          | 64       |
 *
 *
 * \param[in] QKV The QKV tensor in packed format,
@@ -168,6 +170,11 @@ void nvte_fused_attn_bwd_qkvpacked(
 * - D = Dropout(S)
 * - O = D * V.T
 *
 * Support Matrix:
 * | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
 * | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No      | <= 512          | 64       |
 *
 *
 * \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
 * \param[in] KV The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].
 * \param[in] Bias The Bias tensor.
@@ -208,6 +215,11 @@ void nvte_fused_attn_fwd_kvpacked(
     cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed KV input.
 *
 * Support Matrix:
 * | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
 * | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No      | <= 512          | 64       |
 *
 *
 * \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
 * \param[in] KV The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].
...
@@ -19,6 +19,8 @@ from jax.interpreters.mlir import ir, dtype_to_ir_type
 import transformer_engine_jax
 from transformer_engine_jax import DType as TEDType
+from transformer_engine_jax import NVTE_Bias_Type
+from transformer_engine_jax import NVTE_Mask_Type

 for _name, _value in transformer_engine_jax.registrations().items():
     xla_client.register_custom_call_target(_name, _value, platform="CUDA")
@@ -1973,3 +1975,453 @@ def scaled_upper_triang_masked_softmax_bwd(grad_outputs: jnp.ndarray, softmax_ou
     return _scaled_upper_triang_masked_softmax_bwd_p.bind(grad_outputs,
                                                           softmax_outputs,
                                                           scale_factor=scale_factor)
class SelfFusedAttnMax512FwdPrimitive(BasePrimitive):
"""
Self Fused Attention Max Seqlen 512 Forward Primitive
"""
name = "te_self_fused_attn_max_512_forward"
multiple_results = True
@staticmethod
def abstract(
qkv,
bias,
cu_seqlen, # pylint: disable=unused-argument
rng_state, # pylint: disable=unused-argument
*,
attn_bias_type, # pylint: disable=unused-argument
attn_mask_type, # pylint: disable=unused-argument
scaling_factor, # pylint: disable=unused-argument
dropout_probability, # pylint: disable=unused-argument
is_training # pylint: disable=unused-argument
):
"""
Self fused attention max seqlen 512 fwd abstract
"""
qkv_dtype = dtypes.canonicalize_dtype(qkv.dtype)
batch, max_seqlen, nqkv, num_head, head_dim = qkv.shape
assert nqkv == 3
assert qkv.dtype == bias.dtype
output_shape = (batch, max_seqlen, num_head, head_dim)
output_dtype = qkv_dtype
softmax_aux_shape = (batch, num_head, max_seqlen, max_seqlen)
softmax_dtype = qkv_dtype
return (
ShapedArray(output_shape, output_dtype, named_shape=qkv.named_shape), # output
ShapedArray(softmax_aux_shape, softmax_dtype,
named_shape=qkv.named_shape), # softmax_aux
)
@staticmethod
def lowering(ctx, qkv, bias, cu_seqlen, rng_state, *, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
"""
Self fused attention max seqlen 512 fwd lowering rules
"""
qkv_aval, _, _, _ = ctx.avals_in
ir_qkv_type = ir.RankedTensorType(qkv.type)
ir_qkv_shape = ir_qkv_type.shape
ir_bias_type = ir.RankedTensorType(bias.type)
ir_bias_shape = ir_bias_type.shape
ir_cu_seqlen_type = ir.RankedTensorType(cu_seqlen.type)
ir_cu_seqlen_shape = ir_cu_seqlen_type.shape
ir_rng_state_type = ir.RankedTensorType(rng_state.type)
ir_rng_state_shape = ir_rng_state_type.shape
batch, max_seqlen, nqkv, num_head, head_dim = ir_qkv_shape
assert nqkv == 3
output_shape = (batch, max_seqlen, num_head, head_dim)
softmax_aux_shape = (batch, num_head, max_seqlen, max_seqlen)
out_types = [
ir.RankedTensorType.get(output_shape, ir_qkv_type.element_type),
ir.RankedTensorType.get(softmax_aux_shape, ir_qkv_type.element_type)
]
operands = [qkv, bias, cu_seqlen, rng_state]
operand_shapes = [ir_qkv_shape, ir_bias_shape, ir_cu_seqlen_shape, ir_rng_state_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability,
attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
out = custom_caller(SelfFusedAttnMax512FwdPrimitive.name,
args,
opaque,
has_side_effect=False)
return out
_self_fused_attn_max_512_fwd_p = register_primitive(SelfFusedAttnMax512FwdPrimitive)
def self_fused_attn_max_512_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, cu_seqlen: jnp.ndarray,
rng_state: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool):
"""
Wrapper for TE self fused attention max seqlen 512 fwd
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
"""
    # JAX can't bind None, so create dummy tensors to stand in for the None inputs
if rng_state is None:
rng_state = jnp.zeros(2, dtype=jnp.int32)
if bias is None:
assert attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS
bias = jnp.zeros(0, dtype=qkv.dtype)
return _self_fused_attn_max_512_fwd_p.bind(qkv,
bias,
cu_seqlen,
rng_state,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
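# Illustrative usage sketch (not part of this change; shapes and dtypes are
# assumed): the primitive returns both the attention output and the softmax
# auxiliary tensor, and the None guards above supply dummies for bias/rng_state.
#
#     import jax.numpy as jnp
#     batch, seqlen, num_head, head_dim = 2, 512, 16, 64
#     qkv = jnp.zeros((batch, seqlen, 3, num_head, head_dim), jnp.bfloat16)
#     cu_seqlen = jnp.asarray([0, 512, 1024], jnp.int32)    # prefix sums of per-sequence lengths
#     out, softmax_aux = self_fused_attn_max_512_fwd(
#         qkv, None, cu_seqlen, None,
#         attn_bias_type=NVTE_Bias_Type.NVTE_NO_BIAS,
#         attn_mask_type=NVTE_Mask_Type.NVTE_PADDING_MASK,
#         scaling_factor=1.0 / head_dim**0.5,
#         dropout_probability=0.0,
#         is_training=True)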
class SelfFusedAttnMax512BwdPrimitive(BasePrimitive):
"""
Self Fused Attention Max Seqlen 512 Backward Primitive
"""
name = "te_self_fused_attn_max_512_backward"
multiple_results = True
@staticmethod
def abstract(
qkv,
softmax_aux,
doutput,
cu_seqlen, # pylint: disable=unused-argument
*,
attn_bias_type, # pylint: disable=unused-argument
attn_mask_type, # pylint: disable=unused-argument
scaling_factor, # pylint: disable=unused-argument
dropout_probability, # pylint: disable=unused-argument
is_training # pylint: disable=unused-argument
):
"""
        Self fused attention max seqlen 512 bwd abstract
"""
qkv_dtype = dtypes.canonicalize_dtype(qkv.dtype)
assert qkv.dtype == softmax_aux.dtype == doutput.dtype
_, seqlen, _, num_head, _ = qkv.shape
bias_shape = (1, num_head, seqlen, seqlen)
bias_dtype = qkv_dtype
return (
ShapedArray(qkv.shape, qkv_dtype, named_shape=qkv.named_shape), # dqkv
ShapedArray(bias_shape, bias_dtype, named_shape=qkv.named_shape))
@staticmethod
def lowering(ctx, qkv, softmax_aux, doutput, cu_seqlen, *, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
"""
Self fused attention max seqlen 512 bwd lowering rules
"""
qkv_aval, _, _, _ = ctx.avals_in
ir_qkv_type = ir.RankedTensorType(qkv.type)
ir_qkv_shape = ir_qkv_type.shape
ir_softmax_aux_type = ir.RankedTensorType(softmax_aux.type)
ir_softmax_aux_shape = ir_softmax_aux_type.shape
ir_doutput_type = ir.RankedTensorType(doutput.type)
ir_doutput_shape = ir_doutput_type.shape
ir_cu_seqlen_type = ir.RankedTensorType(cu_seqlen.type)
ir_cu_seqlen_shape = ir_cu_seqlen_type.shape
batch, max_seqlen, num_head, head_dim = ir_doutput_shape
dbias_shape = (1, num_head, max_seqlen, max_seqlen)
dbias_dtype = ir_qkv_type.element_type
out_types = [
ir.RankedTensorType.get(ir_qkv_shape, ir_qkv_type.element_type),
ir.RankedTensorType.get(dbias_shape, dbias_dtype)
]
operands = [qkv, softmax_aux, doutput, cu_seqlen]
operand_shapes = [ir_qkv_shape, ir_softmax_aux_shape, ir_doutput_shape, ir_cu_seqlen_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability,
attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
out = custom_caller(SelfFusedAttnMax512BwdPrimitive.name,
args,
opaque,
has_side_effect=False)
return out
_self_fused_attn_max_512_bwd_p = register_primitive(SelfFusedAttnMax512BwdPrimitive)
def self_fused_attn_max_512_bwd(qkv: jnp.ndarray, softmax_aux: jnp.ndarray, doutput: jnp.ndarray,
cu_seqlen: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool):
"""
Wrapper for TE self fused attention max seqlen 512 bwd
Return the gradients of self fused attention with packed qkv input
"""
return _self_fused_attn_max_512_bwd_p.bind(qkv,
softmax_aux,
doutput,
cu_seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
class CrossFusedAttnMax512FwdPrimitive(BasePrimitive):
"""
Cross Fused Attention Forward Max Seqlen 512 Primitive
"""
name = "te_cross_fused_attn_max_512_forward"
multiple_results = True
@staticmethod
def abstract(
q,
kv,
q_cu_seqlen,
kv_cu_seqlen,
rng_state, # pylint: disable=unused-argument
*,
attn_bias_type, # pylint: disable=unused-argument
attn_mask_type, # pylint: disable=unused-argument
scaling_factor, # pylint: disable=unused-argument
dropout_probability, # pylint: disable=unused-argument
is_training # pylint: disable=unused-argument
):
"""
Cross fused attention max seqlen 512 fwd abstract
"""
q_dtype = dtypes.canonicalize_dtype(q.dtype)
batch_q, q_max_seqlen, num_head_q, head_dim_q = q.shape
kv_dtype = dtypes.canonicalize_dtype(kv.dtype)
batch_kv, kv_max_seqlen, nkv, num_head_kv, head_dim_kv = kv.shape
assert q_dtype == kv_dtype
assert batch_q == batch_kv
assert num_head_q == num_head_kv
assert head_dim_q == head_dim_kv
assert nkv == 2
assert q_cu_seqlen.dtype == kv_cu_seqlen.dtype
output_shape = q.shape
output_dtype = q_dtype
softmax_aux_shape = (batch_q, num_head_q, q_max_seqlen, kv_max_seqlen)
softmax_aux_dtype = q_dtype
return (
ShapedArray(output_shape, output_dtype, named_shape=q.named_shape), # output
ShapedArray(softmax_aux_shape, softmax_aux_dtype,
named_shape=q.named_shape), # softmax_aux
)
@staticmethod
def lowering(ctx, q, kv, q_cu_seqlen, kv_cu_seqlen, rng_state, *, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training):
"""
Cross fused attention max seqlen 512 fwd lowering rules
"""
q_aval, kv_aval, _, _, _ = ctx.avals_in
assert q_aval.dtype == kv_aval.dtype
ir_q_type = ir.RankedTensorType(q.type)
ir_q_shape = ir_q_type.shape
ir_kv_type = ir.RankedTensorType(kv.type)
ir_kv_shape = ir_kv_type.shape
ir_q_cu_seqlen_shape = ir.RankedTensorType(q_cu_seqlen.type).shape
ir_kv_cu_seqlen_shape = ir.RankedTensorType(kv_cu_seqlen.type).shape
ir_rng_state_type = ir.RankedTensorType(rng_state.type)
ir_rng_state_shape = ir_rng_state_type.shape
batch, q_max_seqlen, num_head, head_dim = ir_q_shape
kv_max_seqlen = ir_kv_shape[1]
output_shape = (batch, q_max_seqlen, num_head, head_dim)
softmax_aux_shape = (batch, num_head, q_max_seqlen, kv_max_seqlen)
out_types = [
ir.RankedTensorType.get(output_shape, ir_q_type.element_type),
ir.RankedTensorType.get(softmax_aux_shape, ir_q_type.element_type)
]
operands = [q, kv, q_cu_seqlen, kv_cu_seqlen, rng_state]
operand_shapes = [
ir_q_shape, ir_kv_shape, ir_q_cu_seqlen_shape, ir_kv_cu_seqlen_shape, ir_rng_state_shape
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training)
out = custom_caller(CrossFusedAttnMax512FwdPrimitive.name,
args,
opaque,
has_side_effect=False)
return out
_cross_fused_attn_max_512_fwd_p = register_primitive(CrossFusedAttnMax512FwdPrimitive)
def cross_fused_attn_max_512_fwd(q: jnp.ndarray, kv: jnp.ndarray, q_cu_seqlen: jnp.ndarray,
kv_cu_seqlen: jnp.ndarray, rng_state: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
scaling_factor: float, dropout_probability: float,
is_training: bool):
"""
Wrapper for TE cross fused attention max seqlen 512 fwd
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
"""
    # JAX can't bind None, so create a dummy tensor to stand in for the None input
if rng_state is None:
rng_state = jnp.zeros(2, dtype=jnp.int32)
return _cross_fused_attn_max_512_fwd_p.bind(q,
kv,
q_cu_seqlen,
kv_cu_seqlen,
rng_state,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
class CrossFusedAttnMax512BwdPrimitive(BasePrimitive):
"""
Cross Fused Attention Max Seqlen 512 Backward Primitive
"""
name = "te_cross_fused_attn_max_512_backward"
multiple_results = True
@staticmethod
def abstract(
q,
kv,
softmax_aux,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
*,
attn_bias_type, # pylint: disable=unused-argument
attn_mask_type, # pylint: disable=unused-argument
scaling_factor, # pylint: disable=unused-argument
dropout_probability, # pylint: disable=unused-argument
is_training # pylint: disable=unused-argument
):
"""
Cross fused attention max seqlen 512 bwd abstract
"""
q_dtype = dtypes.canonicalize_dtype(q.dtype)
kv_dtype = dtypes.canonicalize_dtype(kv.dtype)
softmax_aux_dtype = dtypes.canonicalize_dtype(softmax_aux.dtype)
doutput_dtype = dtypes.canonicalize_dtype(doutput.dtype)
assert q_dtype == kv_dtype == softmax_aux_dtype == doutput_dtype
assert q_cu_seqlen.dtype == kv_cu_seqlen.dtype
return (
ShapedArray(q.shape, q_dtype, named_shape=q.named_shape), # dq
ShapedArray(kv.shape, kv_dtype, named_shape=kv.named_shape), # dkv
)
@staticmethod
def lowering(ctx, q, kv, softmax_aux, doutput, q_cu_seqlen, kv_cu_seqlen, *, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training):
"""
Cross fused attention max seqlen 512 bwd lowering rules
"""
q_aval, _, _, _, _, _ = ctx.avals_in
ir_q_type = ir.RankedTensorType(q.type)
ir_q_shape = ir_q_type.shape
ir_kv_type = ir.RankedTensorType(kv.type)
ir_kv_shape = ir_kv_type.shape
ir_softmax_aux_type = ir.RankedTensorType(softmax_aux.type)
ir_softmax_aux_shape = ir_softmax_aux_type.shape
ir_doutput_shape = ir.RankedTensorType(doutput.type).shape
ir_q_cu_seqlen_shape = ir.RankedTensorType(q_cu_seqlen.type).shape
ir_kv_cu_seqlen_shape = ir.RankedTensorType(kv_cu_seqlen.type).shape
batch, q_max_seqlen, num_head, head_dim = ir_doutput_shape
kv_max_seqlen = ir_kv_shape[1]
out_types = [
ir.RankedTensorType.get(ir_q_shape, ir_q_type.element_type),
ir.RankedTensorType.get(ir_kv_shape, ir_kv_type.element_type),
]
operands = [q, kv, softmax_aux, doutput, q_cu_seqlen, kv_cu_seqlen]
operand_shapes = [
ir_q_shape, ir_kv_shape, ir_softmax_aux_shape, ir_doutput_shape, ir_q_cu_seqlen_shape,
ir_kv_cu_seqlen_shape
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training)
out = custom_caller(CrossFusedAttnMax512BwdPrimitive.name,
args,
opaque,
has_side_effect=False)
return out
_cross_fused_attn_max_512_bwd_p = register_primitive(CrossFusedAttnMax512BwdPrimitive)
def cross_fused_attn_max_512_bwd(q: jnp.ndarray, kv: jnp.ndarray, softmax_aux: jnp.ndarray,
doutput: jnp.ndarray, q_cu_seqlen: jnp.ndarray,
kv_cu_seqlen: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool):
"""
Wrapper for TE cross fused attention max seqlen 512 bwd
Return the gradients of cross fused attention with packed kv input
"""
return _cross_fused_attn_max_512_bwd_p.bind(q,
kv,
softmax_aux,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
@@ -6,6 +6,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "common/include/transformer_engine/fused_attn.h"
#include "common/include/transformer_engine/transformer_engine.h" #include "common/include/transformer_engine/transformer_engine.h"
#include "jax/csrc/modules.h" #include "jax/csrc/modules.h"
#include "jax/csrc/utils.h" #include "jax/csrc/utils.h"
...@@ -43,6 +44,11 @@ pybind11::dict Registrations() { ...@@ -43,6 +44,11 @@ pybind11::dict Registrations() {
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForward); EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForward);
dict["te_scaled_upper_triang_masked_softmax_backward"] = dict["te_scaled_upper_triang_masked_softmax_backward"] =
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward); EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward);
dict["te_self_fused_attn_max_512_forward"] = EncapsulateFunction(SelfFusedAttnMax512Forward);
dict["te_self_fused_attn_max_512_backward"] = EncapsulateFunction(SelfFusedAttnMax512Backward);
dict["te_cross_fused_attn_max_512_forward"] = EncapsulateFunction(CrossFusedAttnMax512Forward);
dict["te_cross_fused_attn_max_512_backward"] =
EncapsulateFunction(CrossFusedAttnMax512Backward);
    return dict;
}

@@ -52,15 +58,28 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
    m.def("pack_gemm_descriptor", &PackCustomCallGemmDescriptor);
    m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor);
    m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor);
    m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor);
    m.def("is_fused_attn_kernel_available", &IsFusedAttnKernelAvailable);

    pybind11::enum_<DType>(m, "DType", pybind11::module_local())
        .value("kByte", DType::kByte)
        .value("kInt32", DType::kInt32)
.value("KInt64", DType::kInt64)
.value("kFloat32", DType::kFloat32) .value("kFloat32", DType::kFloat32)
.value("kFloat16", DType::kFloat16) .value("kFloat16", DType::kFloat16)
.value("kBFloat16", DType::kBFloat16) .value("kBFloat16", DType::kBFloat16)
.value("kFloat8E4M3", DType::kFloat8E4M3) .value("kFloat8E4M3", DType::kFloat8E4M3)
.value("kFloat8E5M2", DType::kFloat8E5M2); .value("kFloat8E5M2", DType::kFloat8E5M2);
pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local())
.value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS)
.value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS)
.value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
pybind11::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type", pybind11::module_local())
.value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
.value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
.value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK);
}

} // namespace jax
...
@@ -9,6 +9,7 @@
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <cudnn.h>

#include <functional>
#include <numeric>

@@ -19,6 +20,7 @@
#include "common/common.h"
#include "transformer_engine/activation.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/fused_attn.h"
#include "transformer_engine/gemm.h"
#include "transformer_engine/layer_norm.h"
#include "transformer_engine/rmsnorm.h"

@@ -78,6 +80,25 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch,
        SoftmaxDescriptor{batch, pad_batch, heads, q_seqlen, k_seqlen, dtype, scale_factor});
}
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
return PackOpaque(CustomCallFusedAttnDescriptor{batch, num_head, q_max_seqlen, kv_max_seqlen,
head_dim, scaling_factor, dropout_probability,
bias_type, mask_type, dtype, is_training});
}
bool IsFusedAttnKernelAvailable() {
#if (CUDNN_VERSION >= 8901)
auto major = cudaDevicePropertiesManager::Instance().GetMajor();
// Fused attention requires at least Ampere
return major >= 8;
#else
return false;
#endif
}
void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream,
                   void *output) {
    auto input_shape = std::vector<size_t>{rows, cols};

@@ -718,5 +739,333 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers,
        grad_output_tensor.data(), softmax_output_tensor.data(), dgrad_tensor.data(),
        desc.scale_factor, stream);
}
void SelfFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input
void *qkv = buffers[0];
void *bias = buffers[1];
void *cu_seqlens = buffers[2];
void *rng_state = buffers[3];
// output
void *output = buffers[4];
void *softmax_aux = buffers[5];
auto batch = descriptor.batch;
auto num_head = descriptor.num_head;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim;
NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
"q_max_seqlen should be equal to kv_max_seqlen in the self attention.");
auto dtype = descriptor.dtype;
auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
// FP16/BF16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto o_tensor =
TensorWrapper(output, std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}, dtype);
auto cu_seqlens_tensor =
TensorWrapper(cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{1}, DType::kInt64);
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
TensorWrapper query_workspace_tensor;
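    // The first call with an empty workspace tensor only queries the required workspace size.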
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, descriptor.is_training,
descriptor.scaling_factor, descriptor.dropout_probability,
NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type,
descriptor.mask_type, query_workspace_tensor.data(), stream);
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
output_s->data.dptr = softmax_aux;
size_t workspace_size =
query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size);
auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
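    // The second call runs the actual forward kernel with the allocated workspace.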
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, descriptor.is_training,
descriptor.scaling_factor, descriptor.dropout_probability,
NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type,
descriptor.mask_type, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors);
}
void SelfFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input
void *qkv = buffers[0];
void *softmax_aux = buffers[1];
void *doutput = buffers[2];
void *cu_seqlens = buffers[3];
// output
void *dqkv = buffers[4];
void *dp = softmax_aux;
void *dbias = buffers[5];
auto batch = descriptor.batch;
auto num_head = descriptor.num_head;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim;
NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
"q_max_seqlen should be equal to kv_max_seqlen in the self attention.");
auto dtype = descriptor.dtype;
auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim};
auto output_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
    // A trick: flash attention needs the fwd output for its bwd pass,
    // but it is not needed when seqlen <= 512
auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
// FP16/BF16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
auto cu_seqlens_tensor =
TensorWrapper(cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
// Currently, no rng_state required for bwd
auto rng_state = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt64);
    // TODO: need to think about how to pass aux_output_tensors
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
aux_output_tensors.size = 1;
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
output_s->data.shape = std::vector<size_t>{batch, num_head, q_max_seqlen, kv_max_seqlen};
output_s->data.dptr = softmax_aux;
TensorWrapper query_workspace_tensor;
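    // Same two-call pattern as the forward: first query the workspace size, then run the kernel.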
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for FP16/BF16
s_tensor.data(), // not used for FP16/BF16
&aux_output_tensors, dqkv_tensor.data(), dbias_tensor.data(),
cu_seqlens_tensor.data(), q_max_seqlen, descriptor.scaling_factor,
descriptor.dropout_probability,
NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type,
descriptor.mask_type, query_workspace_tensor.data(), stream);
size_t workspace_size =
query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size);
auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for FP16/BF16
s_tensor.data(), // not used for FP16/BF16
&aux_output_tensors, dqkv_tensor.data(), dbias_tensor.data(),
cu_seqlens_tensor.data(), q_max_seqlen, descriptor.scaling_factor,
descriptor.dropout_probability,
NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type,
descriptor.mask_type, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors);
}
void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input
void *q = buffers[0];
void *kv = buffers[1];
void *q_cu_seqlens = buffers[2];
void *kv_cu_seqlens = buffers[3];
void *rng_state = buffers[4];
// output
void *output = buffers[5];
void *softmax_aux = buffers[6];
auto batch = descriptor.batch;
auto num_head = descriptor.num_head;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim;
auto dtype = descriptor.dtype;
auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_head, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
// TODO(rewang): add bias for cross attn?
auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
// FP16/BF16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto o_tensor =
TensorWrapper(output, std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{1}, DType::kInt64);
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
TensorWrapper query_workspace_tensor;
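    // First call queries the workspace size; the kernel itself runs in the second call below.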
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
descriptor.scaling_factor, descriptor.dropout_probability,
NVTE_QKV_Layout::NVTE_KV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type,
query_workspace_tensor.data(), stream);
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
output_s->data.dptr = softmax_aux;
size_t workspace_size =
query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size);
auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
descriptor.scaling_factor, descriptor.dropout_probability,
NVTE_QKV_Layout::NVTE_KV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type,
workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors);
}
void CrossFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input
void *q = buffers[0];
void *kv = buffers[1];
void *softmax_aux = buffers[2];
void *doutput = buffers[3];
void *q_cu_seqlens = buffers[4];
void *kv_cu_seqlens = buffers[5];
// output
void *dq = buffers[6];
void *dkv = buffers[7];
void *dp = softmax_aux;
auto batch = descriptor.batch;
auto num_head = descriptor.num_head;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim;
auto dtype = descriptor.dtype;
auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_head, head_dim};
auto output_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
    // A trick: flash attention needs the fwd output for its bwd pass,
    // but it is not needed when seqlen <= 512
auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
// FP16/BF16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype);
// TODO(rewang): generalize cross attn
auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
// Currently, no rng_state required for bwd
auto rng_state = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt64);
// TODO(rewang): need to think about how to pass aux_output_tensors
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
aux_output_tensors.size = 1;
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
output_s->data.shape = std::vector<size_t>{batch * num_head, q_max_seqlen, kv_max_seqlen};
output_s->data.dptr = softmax_aux;
TensorWrapper query_workspace_tensor;
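    // Query the workspace size first, then launch the kernel with the allocated workspace.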
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for FP16/BF16
s_tensor.data(), // not used for FP16/BF16
&aux_output_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
descriptor.scaling_factor, descriptor.dropout_probability,
NVTE_QKV_Layout::NVTE_KV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type,
query_workspace_tensor.data(), stream);
size_t workspace_size =
query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size);
auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for FP16/BF16
s_tensor.data(), // not used for FP16/BF16
&aux_output_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
descriptor.scaling_factor, descriptor.dropout_probability,
NVTE_QKV_Layout::NVTE_KV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type,
workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors);
}
} // namespace jax
} // namespace transformer_engine
...
@@ -17,6 +17,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "transformer_engine/fused_attn.h"
#include "transformer_engine/logging.h"
#include "transformer_engine/transformer_engine.h"

@@ -94,6 +95,27 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch,
                                                size_t q_seqlen, size_t k_seqlen, DType dtype,
                                                float scale_factor);
struct CustomCallFusedAttnDescriptor {
size_t batch;
size_t num_head;
size_t q_max_seqlen;
size_t kv_max_seqlen;
size_t head_dim;
float scaling_factor;
float dropout_probability;
NVTE_Bias_Type bias_type;
NVTE_Mask_Type mask_type;
DType dtype;
bool is_training;
};
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, DType dtype, bool is_training);
bool IsFusedAttnKernelAvailable();
void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

@@ -144,6 +166,18 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers,
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                            std::size_t opaque_len);
void SelfFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void SelfFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void CrossFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
} // namespace jax
} // namespace transformer_engine
...
@@ -75,6 +75,16 @@ class cudaDevicePropertiesManager {
        return prop_.multiProcessorCount;
    }
int GetMajor() {
if (!prop_queried_) {
int device_id;
NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
cudaGetDeviceProperties(&prop_, device_id);
prop_queried_ = true;
}
return prop_.major;
}
 private:
    bool prop_queried_ = false;
    cudaDeviceProp prop_;
...
@@ -6,18 +6,24 @@ Wrapper module for Transformer related layers with FP8 support.
"""
import functools
from enum import Enum
from math import sqrt
from typing import Any, Callable, Optional, Sequence, Tuple, Union
import warnings
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import dtypes
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap

from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
from ..fused_attn import AttnBiasType, AttnMaskType
from ..fused_attn import is_fused_attn_kernel_available
from ..fused_attn import self_fused_attn, cross_fused_attn
from ..softmax import SoftmaxType
from ..sharding import infer_major_sharding_type, infer_sharding_type
from ..sharding import global_shard_resource, ShardingType

@@ -129,6 +135,7 @@ def combine_biases(*masks: Optional[Array]):
def core_attention(query: Array,
                   key: Array,
                   value: Array,
scale_factor: float,
                   transpose_batch_sequence: bool,
                   softmax_type: SoftmaxType = SoftmaxType.SCALED,
                   softmax_sharding_type: ShardingType = ShardingType.SINGLE,

@@ -159,6 +166,7 @@ def core_attention(query: Array,
    attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)

    attn_weights = Softmax(softmax_type=softmax_type,
scale_factor=scale_factor,
                           sharding_type=softmax_sharding_type)(attn_weights, mask, bias)

    if not deterministic and dropout_rate > 0.:

@@ -181,8 +189,8 @@ dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, N

class AttentionType(Enum):
    """Attention type."""
    PADDING = AttnMaskType.PADDING_MASK
    CAUSAL = AttnMaskType.CAUSAL_MASK


class MultiHeadAttention(nn.Module):

@@ -312,9 +320,8 @@ class MultiHeadAttention(nn.Module):
            Output tensors.
        """
        def query_init(*args):
            depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
            return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0)
        def qkv_init(key, shape, dtype):

@@ -349,6 +356,43 @@ class MultiHeadAttention(nn.Module):
        first_sharding_type, second_sharding_type = infer_sharding_type()
canonicalize_dtype = dtypes.canonicalize_dtype(self.dtype)
q_seqlen = inputs_q.shape[0] if self.transpose_batch_sequence else inputs_q.shape[1]
kv_seqlen = inputs_kv.shape[0] if self.transpose_batch_sequence else inputs_kv.shape[1]
fused_attn_supported_seqlen = [128, 256, 384, 512]
use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \
self.dropout_rate == 0 and canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \
q_seqlen in fused_attn_supported_seqlen and kv_seqlen in fused_attn_supported_seqlen \
and is_fused_attn_kernel_available()
if not use_fused_attn:
reason = ""
if decode:
reason += f"decode=False is required but got {decode}, "
if self.transpose_batch_sequence:
reason += f"transpose_batch_sequence=False is required " \
f"but got {self.transpose_batch_sequence}, "
if not self.fuse_qkv:
reason += f"fuse_qkv=True is required but got {self.fuse_qkv}, "
if self.dropout_rate != 0:
# TODO(rewang): add dropout support
reason += f"no dropout is required but got dropout_rate={self.dropout_rate}, "
if canonicalize_dtype not in [jnp.bfloat16, jnp.float16]:
reason += f"dtype in [BF16, FP16] is required " \
f"but got dtype={canonicalize_dtype}, "
if q_seqlen not in fused_attn_supported_seqlen:
reason += f"q_seqlen in {fused_attn_supported_seqlen} is required " \
f"but got {q_seqlen=}, "
if kv_seqlen not in fused_attn_supported_seqlen:
reason += f"kv_seqlen in {fused_attn_supported_seqlen} is required " \
f"but got {kv_seqlen=}, "
if not is_fused_attn_kernel_available():
reason += "GPU arch >= Ampere and cuDNN >= 8.9.1 are required, "
warnings.warn(
f"Fused attention is not enabled, " \
f"{reason}fall back to unfused attention")
        residual = inputs_q
        if self.fuse_qkv:
            if inputs_q is inputs_kv:

@@ -369,12 +413,8 @@ class MultiHeadAttention(nn.Module):
                    bias_init=self.bias_init,
                    name='qkv',
                    dtype=self.dtype)(inputs_q)
                if not use_fused_attn:
                    query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
            else:
                query, ln_out = LayerNormDenseGeneral(
                    enable_layernorm=not self.output_layernorm,

@@ -386,7 +426,6 @@ class MultiHeadAttention(nn.Module):
                    sharding_type=first_sharding_type,
                    transpose_batch_sequence=self.transpose_batch_sequence,
                    return_layernorm_output=self.apply_residual_connection_post_layernorm,
                    scale_axes=('embed',),
                    kernel_axes=('embed', 'joined_kv'),
                    use_bias=self.use_bias,

@@ -404,11 +443,8 @@ class MultiHeadAttention(nn.Module):
                    bias_init=self.bias_init,
                    name='kv',
                    dtype=self.dtype)(inputs_kv)
                if not use_fused_attn:
                    key, value = jnp.split(kv_proj, [1], axis=-2)
        else:
            kv_projection = functools.partial(
                DenseGeneral,

@@ -430,7 +466,6 @@ class MultiHeadAttention(nn.Module):
                sharding_type=first_sharding_type,
                transpose_batch_sequence=self.transpose_batch_sequence,
                return_layernorm_output=True,
                scale_axes=('embed',),
                kernel_axes=('embed', 'joined_kv'),
                use_bias=self.use_bias,

@@ -446,21 +481,21 @@ class MultiHeadAttention(nn.Module):
            key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
            value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv)
        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

        if not use_fused_attn:
            query = query.reshape((query.shape[0], query.shape[1], self.num_heads, self.head_dim))
            key = key.reshape((key.shape[0], key.shape[1], self.num_heads, self.head_dim))
            value = value.reshape((value.shape[0], value.shape[1], self.num_heads, self.head_dim))

            qkv_sharding_constraint = \
                ('length', 'batch', 'heads', 'kv') \
                if self.transpose_batch_sequence \
                else ('batch', 'length', 'heads', 'kv')
            query = nn_partitioning.with_sharding_constraint(query, qkv_sharding_constraint)
            key = nn_partitioning.with_sharding_constraint(key, qkv_sharding_constraint)
            value = nn_partitioning.with_sharding_constraint(value, qkv_sharding_constraint)
        if decode:
            is_initialized = self.has_variable('cache', 'cached_key')

@@ -502,30 +537,74 @@ class MultiHeadAttention(nn.Module):
                bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0),
                                                   jnp.reshape(cur_index, (-1)), 1, -2)
scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0
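        # e.g., head_dim=64 with scale_attn_logits=True gives scale_factor = 1/sqrt(64) = 0.125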
        dropout_rng = None
        if not deterministic and self.dropout_rate > 0.:
            dropout_rng = self.make_rng(self.dropout_rng_name)
        if use_fused_attn:
            assert mask is not None and mask.ndim == 4    # (b, 1, s_q, s_kv)
            assert not self.transpose_batch_sequence

            # TODO(rewang): make it configurable for pre_scale_bias
            attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS

            if inputs_q is inputs_kv:
                qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim))
                qkv_sharding_constraint = ('batch', 'length', 'qkv_dim', 'heads', 'kv')
                qkv_proj = nn_partitioning.with_sharding_constraint(qkv_proj,
                                                                    qkv_sharding_constraint)
                x = self_fused_attn(qkv_proj,
                                    bias,
                                    mask,
                                    dropout_rng,
                                    attn_bias_type=attn_bias_type,
                                    attn_mask_type=self.attn_type.value,
                                    scaling_factor=scale_factor,
                                    dropout_probability=self.dropout_rate,
is_training=not deterministic,
sharding_type=first_sharding_type)
else:
assert bias is None
query = query.reshape((*query.shape[:-1], self.num_heads, self.head_dim))
kv_proj = kv_proj.reshape((*kv_proj.shape[:-1], self.num_heads, self.head_dim))
q_sharding_constraint = ('batch', 'length', 'heads', 'kv')
kv_sharding_constraint = ('batch', 'length', 'kv_dim', 'heads', 'kv')
query = nn_partitioning.with_sharding_constraint(query, q_sharding_constraint)
kv_proj = nn_partitioning.with_sharding_constraint(kv_proj, kv_sharding_constraint)
x = cross_fused_attn(query,
kv_proj,
mask,
dropout_rng,
attn_bias_type=attn_bias_type,
attn_mask_type=self.attn_type.value,
scaling_factor=scale_factor,
dropout_probability=self.dropout_rate,
is_training=not deterministic,
sharding_type=first_sharding_type)
else:
softmax_type = SoftmaxType.SCALED
if self.attn_type is AttentionType.PADDING:
if mask is not None:
softmax_type = SoftmaxType.SCALED_MASKED
else:
softmax_type = SoftmaxType.SCALED_UPPER_TRIANG_MASKED
x = core_attention(query,
key,
value,
scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence,
softmax_type=softmax_type,
softmax_sharding_type=first_sharding_type,
mask=mask,
bias=bias,
dropout_rng=dropout_rng,
dropout_rate=self.dropout_rate,
deterministic=deterministic,
dtype=self.dtype,
float32_logits=self.float32_logits)
        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
...
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX multi-head attention modules"""
from enum import Enum
from functools import partial
import jax
import jax.numpy as jnp
import transformer_engine_jax
from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
from .cpp_extensions import cross_fused_attn_max_512_fwd, cross_fused_attn_max_512_bwd
from .cpp_extensions import self_fused_attn_max_512_fwd, self_fused_attn_max_512_bwd
from .sharding import get_fused_attn_sharding_meta
from .sharding import ShardingType
from .sharding import xmap_runner
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
def is_fused_attn_kernel_available():
"""
    Check whether the fused attention kernel is available.
"""
return transformer_engine_jax.is_fused_attn_kernel_available()
class AttnBiasType(Enum):
"""Attention Bias Type."""
NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS
PRE_SCALE_BIAS = NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS
POST_SCALE_BIAS = NVTE_Bias_Type.NVTE_POST_SCALE_BIAS
class AttnMaskType(Enum):
"""Attention Mask Type."""
NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK
PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK
CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK
def self_fused_attn(qkv: jnp.ndarray,
bias: jnp.ndarray,
mask: jnp.ndarray,
rng_state: jnp.ndarray,
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
sharding_type: ShardingType = ShardingType.SINGLE):
"""
Self fused attention wrapper
"""
assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
"Fused_attn_max_512 does not support row-split tensor parallelism currently."
if sharding_type is ShardingType.SINGLE:
output = _self_fused_attn_max_512(qkv,
bias,
mask,
rng_state,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
else:
dp_axis_name = "batch"
tp_axis_name = "model"
inputs = [qkv, bias, mask, rng_state]
batch, seqlen, _, num_head, head_dim = qkv.shape
output_shape = [batch, seqlen, num_head, head_dim]
sharding_meta = get_fused_attn_sharding_meta(
sharding_type, [x.shape if x is not None else None for x in inputs], [output_shape],
dp_dims=([0, None, 0, None], [0]),
tp_dims=([3, 1, None, None], [2]),
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
inputs_ = tuple(
jnp.reshape(x, new_shape) if x is not None else None
for x, new_shape in zip(inputs, sharding_meta.input_shapes))
partial_self_fused_attn_max_512 = partial(_self_fused_attn_max_512,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
output_ = xmap_runner(partial_self_fused_attn_max_512, sharding_meta.in_axes,
sharding_meta.out_axes[0], sharding_meta.axis_resources, inputs_)
output = jnp.reshape(output_, sharding_meta.output_shapes[0])
return output
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _self_fused_attn_max_512(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
rng_state: jnp.ndarray, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
output, _ = _self_fused_attn_max_512_fwd(qkv,
bias,
mask,
rng_state,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output
def _self_fused_attn_max_512_fwd(qkv, bias, mask, rng_state, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
cu_seqlen = jnp.cumsum(seqlen)
cu_seqlen = jnp.hstack((0, cu_seqlen))
output, softmax_aux = self_fused_attn_max_512_fwd(qkv,
bias,
cu_seqlen,
rng_state,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output, (qkv, softmax_aux, cu_seqlen)
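# Worked example of the mask -> cu_seqlen converter above (values assumed):
# for a (2, 1, 4, 4) padding mask in which 0 marks attendable positions and the
# two sequences have 3 and 2 valid tokens,
#     seqlen    = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)  # [3 2]
#     cu_seqlen = jnp.hstack((0, jnp.cumsum(seqlen)))                             # [0 3 5]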
def _self_fused_attn_max_512_bwd(attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training, ctx, grad):
qkv, softmax_aux, cu_seqlen = ctx
doutput = grad
grad_qkv, grad_bias = self_fused_attn_max_512_bwd(qkv,
softmax_aux,
doutput,
cu_seqlen,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return grad_qkv, grad_bias, None, None
_self_fused_attn_max_512.defvjp(_self_fused_attn_max_512_fwd, _self_fused_attn_max_512_bwd)
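# Gradient sketch (illustrative; the tensors are assumed, not defined here): the
# VJP registered above routes jax.grad through _self_fused_attn_max_512_bwd,
# yielding gradients for qkv and bias:
#
#     def loss_fn(qkv, bias):
#         out = self_fused_attn(qkv, bias, mask, None,
#                               attn_bias_type=AttnBiasType.POST_SCALE_BIAS,
#                               attn_mask_type=AttnMaskType.PADDING_MASK,
#                               scaling_factor=0.125,
#                               dropout_probability=0.0,
#                               is_training=True)
#         return out.sum()
#
#     dqkv, dbias = jax.grad(loss_fn, argnums=(0, 1))(qkv, bias)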
def cross_fused_attn(q: jnp.ndarray,
kv: jnp.ndarray,
mask: jnp.ndarray,
rng_state: jnp.ndarray,
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
sharding_type: ShardingType = ShardingType.SINGLE):
"""
Cross multi-head attention wrapper
"""
assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
"Fused_attn_max_512 does not support row-split tensor parallelism currently."
if sharding_type is ShardingType.SINGLE:
output = _cross_fused_attn_max_512(q,
kv,
mask,
rng_state,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
else:
dp_axis_name = "batch"
tp_axis_name = "model"
inputs = [q, kv, mask, rng_state]
output_shape = q.shape
sharding_meta = get_fused_attn_sharding_meta(
sharding_type, [x.shape if x is not None else None for x in inputs], [output_shape],
dp_dims=([0, 0, 0, None], [0]),
tp_dims=([2, 3, None, None], [2]),
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
inputs_ = tuple(
jnp.reshape(x, new_shape) if x is not None else None
for x, new_shape in zip(inputs, sharding_meta.input_shapes))
partial_cross_fused_attn_max_512 = partial(_cross_fused_attn_max_512,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
output_ = xmap_runner(partial_cross_fused_attn_max_512, sharding_meta.in_axes,
sharding_meta.out_axes[0], sharding_meta.axis_resources, inputs_)
output = jnp.reshape(output_, sharding_meta.output_shapes[0])
return output
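
# Illustrative only: a hypothetical call into cross_fused_attn. The shapes and
# the enum member names (NO_BIAS, PADDING_MASK) are assumptions, not taken
# from this PR: q = [batch, q_seq, heads, dim], kv = [batch, kv_seq, 2, heads,
# dim], mask = [batch, 1, q_seq, kv_seq] with 0 = attend.
def _demo_cross_fused_attn_call():
    import jax
    import jax.numpy as jnp
    q = jnp.zeros((4, 128, 16, 64), jnp.bfloat16)
    kv = jnp.zeros((4, 512, 2, 16, 64), jnp.bfloat16)
    mask = jnp.zeros((4, 1, 128, 512), jnp.uint8)    # attend everywhere
    rng_state = jax.random.PRNGKey(0)
    return cross_fused_attn(q, kv, mask, rng_state,
                            attn_bias_type=AttnBiasType.NO_BIAS,
                            attn_mask_type=AttnMaskType.PADDING_MASK,
                            scaling_factor=64**-0.5,    # 1 / sqrt(head_dim)
                            dropout_probability=0.1,
                            is_training=True,
                            sharding_type=ShardingType.SINGLE)
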
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _cross_fused_attn_max_512(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray,
rng_state: jnp.ndarray, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
output, _ = _cross_fused_attn_max_512_fwd(q,
kv,
mask,
rng_state,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output
def _cross_fused_attn_max_512_fwd(q, kv, mask, rng_state, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
    # Query lengths come from the first key column of the mask and key/value
    # lengths from the first query row, both with the 0 = attend convention.
    q_seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
    q_cu_seqlen = jnp.cumsum(q_seqlen)
    q_cu_seqlen = jnp.hstack((0, q_cu_seqlen))
    kv_seqlen = jnp.sum(mask[:, :, 0, :] == 0, axis=(-1, -2), dtype=jnp.int32)
    kv_cu_seqlen = jnp.cumsum(kv_seqlen)
    kv_cu_seqlen = jnp.hstack((0, kv_cu_seqlen))
output, softmax_aux = cross_fused_attn_max_512_fwd(q,
kv,
q_cu_seqlen,
kv_cu_seqlen,
rng_state,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output, (softmax_aux, q, kv, q_cu_seqlen, kv_cu_seqlen)
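
# Illustrative only: the row-0 / column-0 slices above recover both lengths
# from a cross-attention padding mask (0 = attend). One sequence with a query
# length of 2 out of 4 and a key/value length of 3 out of 6:
def _demo_cross_cu_seqlen():
    import jax.numpy as jnp
    mask = jnp.ones((1, 1, 4, 6))
    mask = mask.at[0, 0, :2, :3].set(0)
    q_seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
    kv_seqlen = jnp.sum(mask[:, :, 0, :] == 0, axis=(-1, -2), dtype=jnp.int32)
    assert (q_seqlen == jnp.array([2])).all()
    assert (kv_seqlen == jnp.array([3])).all()
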
def _cross_fused_attn_max_512_bwd(attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training, ctx, grad):
softmax_aux, q, kv, q_cu_seqlen, kv_cu_seqlen = ctx
doutput = grad
grad_q, grad_kv = cross_fused_attn_max_512_bwd(q,
kv,
softmax_aux,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return grad_q, grad_kv, None, None
_cross_fused_attn_max_512.defvjp(_cross_fused_attn_max_512_fwd, _cross_fused_attn_max_512_bwd)
...@@ -8,6 +8,7 @@ Sharding Meta for xmap with CustomCall
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from itertools import repeat
from typing import Union, Tuple, Dict, Callable, Sequence
from jax.interpreters import pxla
import jax
...@@ -315,6 +316,121 @@ class FP8MetaShardingMetaGenerator(ShardingMetaGenerator):
                                               axis_resource, (), ())
class FusedAttnShardingMetaGenerator(ShardingMetaGenerator):
"""
FusedAttnShardingMetaGenerator
"""
def get_dp_sharding_meta(
self,
input_shapes: Tuple[Tuple[int, ...]],
output_shapes: Tuple[Tuple[int, ...]],
dp_dims: Tuple[Tuple[int, ...]],
tp_dims: Tuple[Tuple[int, ...]], # pylint: disable=unused-argument
dp_axis_name: str = 'data',
tp_axis_name: str = 'model' # pylint: disable=unused-argument
) -> ShardingMeta:
"""get_dp_sharding_meta"""
dummy_tp_dims = [repeat(None), repeat(None)]
return FusedAttnShardingMetaGenerator._get_dptp_sharding_meta(input_shapes, output_shapes,
dp_dims, dummy_tp_dims,
dp_axis_name, None)
def get_tp_col_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_tp_col_sharding_meta"""
return FusedAttnShardingMetaGenerator._get_tp_sharding_meta(*argv, **kwargs)
def get_tp_row_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_tp_row_sharding_meta"""
return FusedAttnShardingMetaGenerator._get_tp_sharding_meta(*argv, **kwargs)
def get_dp_tp_col_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_dp_tp_col_sharding_meta"""
return FusedAttnShardingMetaGenerator._get_dptp_sharding_meta(*argv, **kwargs)
def get_dp_tp_row_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_dp_tp_row_sharding_meta"""
return FusedAttnShardingMetaGenerator._get_dptp_sharding_meta(*argv, **kwargs)
@staticmethod
def _get_tp_sharding_meta(
input_shapes: Tuple[Tuple[int, ...]],
output_shapes: Tuple[Tuple[int, ...]],
dp_dims: Tuple[Tuple[int, ...]], # pylint: disable=unused-argument
tp_dims: Tuple[Tuple[int, ...]],
dp_axis_name: str = 'data', # pylint: disable=unused-argument
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_tp_sharding_meta"""
dummy_dp_dims = [repeat(None), repeat(None)]
return FusedAttnShardingMetaGenerator._get_dptp_sharding_meta(input_shapes, output_shapes,
dummy_dp_dims, tp_dims, None,
tp_axis_name)
@staticmethod
def _get_dptp_sharding_meta(input_shapes: Tuple[Tuple[int, ...]],
output_shapes: Tuple[Tuple[int, ...]],
dp_dims: Tuple[Tuple[int, ...]],
tp_dims: Tuple[Tuple[int, ...]],
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_dp_tp_sharding_meta"""
dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
input_dp_dims, output_dp_dims = dp_dims
input_tp_dims, output_tp_dims = tp_dims
input_new_shapes = []
in_axes = []
        for input_shape, dp_dim, tp_dim in zip(input_shapes, input_dp_dims, input_tp_dims):
            in_axis = {}
            # Guard on input_shape so absent inputs (e.g. bias=None) pass through.
            if dp_dim is not None and input_shape is not None:
                in_axis[dp_dim] = dp_axis_name
                assert input_shape[dp_dim] % dp_size == 0, \
                    f"The batch dimension of input_shape must be divisible by the " \
                    f"data parallelism size, but got {input_shape[dp_dim]=} and {dp_size=}."
                input_shape = (*input_shape[:dp_dim], dp_size, input_shape[dp_dim] // dp_size,
                               *input_shape[dp_dim + 1:])

                # the input shape has been expanded for dp_dim, so tp_dim shifts
                # by one if tp_dim >= dp_dim
                if tp_dim is not None and tp_dim >= dp_dim:
                    tp_dim = tp_dim + 1

            if tp_dim is not None and input_shape is not None:
                in_axis[tp_dim] = tp_axis_name
                assert input_shape[tp_dim] % tp_size == 0, \
                    f"The tensor-parallel dimension of input_shape must be divisible by the " \
                    f"tensor parallelism size, but got {input_shape[tp_dim]=} and {tp_size=}."
                input_shape = (*input_shape[:tp_dim], tp_size, input_shape[tp_dim] // tp_size,
                               *input_shape[tp_dim + 1:])

            in_axes.append(in_axis)
            input_new_shapes.append(input_shape)
output_new_shapes = output_shapes
out_axes = []
        for dp_dim, tp_dim in zip(output_dp_dims, output_tp_dims):
            out_axis = {}
            if dp_dim is not None:
                out_axis[dp_dim] = dp_axis_name
                # mirror the input handling: the dp split shifts later tp dims by one
                if tp_dim is not None and tp_dim >= dp_dim:
                    tp_dim = tp_dim + 1
            if tp_dim is not None:
                out_axis[tp_dim] = tp_axis_name
            out_axes.append(out_axis)
axis_resources = {}
if dp_axis_name is not None:
axis_resources[dp_axis_name] = dp_mesh_axis
if tp_axis_name is not None:
axis_resources[tp_axis_name] = tp_mesh_axis
return ShardingMeta(tuple(in_axes), out_axes, axis_resources, input_new_shapes,
output_new_shapes)
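
# Illustrative only: the dimension-splitting rule implemented above, on
# hypothetical shapes. A sharded dim of size S becomes an (n_devices,
# S // n_devices) pair, and a tp dim at or after the dp dim shifts right by
# one because the dp split inserts a new axis.
def _demo_split_dim():

    def split_dim(shape, dim, n_devices):
        assert shape[dim] % n_devices == 0
        return (*shape[:dim], n_devices, shape[dim] // n_devices, *shape[dim + 1:])

    qkv_shape = (32, 512, 3, 16, 64)    # [batch, seq, 3, heads, head_dim]
    dp_split = split_dim(qkv_shape, 0, 8)    # dp over batch with 8 devices
    assert dp_split == (8, 4, 512, 3, 16, 64)
    tp_dim = 3 + 1    # the heads dim shifted past the new device axis
    assert split_dim(dp_split, tp_dim, 4) == (8, 4, 512, 3, 4, 4, 64)
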
class DotShardingMetaGenerator(ShardingMetaGenerator):
    """
    DotShardingMetaGenerator
...@@ -884,6 +1000,21 @@ def get_softmax_sharding_meta(stype: ShardingType,
                              dp_axis_name, tp_axis_name)
def get_fused_attn_sharding_meta(stype: ShardingType,
input_shapes: Tuple[Tuple[int, ...]],
output_shapes: Tuple[Tuple[int, ...]],
dp_dims: Tuple[Tuple[int, ...]],
tp_dims: Tuple[Tuple[int, ...]],
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""
    get_fused_attn_sharding_meta
"""
return FusedAttnShardingMetaGenerator().get_sharding_meta(stype, input_shapes, output_shapes,
dp_dims, tp_dims, dp_axis_name,
tp_axis_name)
def xmap_runner(func: Callable, in_axes: Tuple[Dict, ...],
                out_axes: Union[Dict, Tuple[str, ...], Tuple[Union[Dict, Tuple], ...]],
                axis_resources: Dict, inputs: Tuple):
...