Unverified Commit 66ff2e36 authored by zlsh80826, committed by GitHub

[JAX] flash attention integration (#345)



* Fix flash attention dropout probability during inference
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add output as the fused attention ctx tensor
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add rng_state as the fused attention ctx tensors
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add flash attention supported lengths to the fused attention
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refactor attention primitive to reuse abstract shaped array
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Detect backend type to allocate appropriate ctx size
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Skip dropout correctness check instead of returning success
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Use cudaMemsetAsync and enhance the error handling
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add flash attention kernel elts_per_thread update
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove redundant max 512 suffix
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Keep only DType and remove NVTEDType from Python
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix a float32_attention_logits bug
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Re-calculate workspace size for self attention
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Enhance bias/dbias shape guard
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Enhance the seed/rng_state checker
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Use jax.core.ShapedArray as jax.abstract_arrays is deprecated
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Enhance the unittest docs
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent 403ade2f
......@@ -4,7 +4,7 @@
import pytest
import jax.numpy as jnp
from jax.abstract_arrays import ShapedArray
from jax.core import ShapedArray
from transformer_engine_jax import DType
from transformer_engine.jax.cpp_extensions import te_dtype_to_jax_dtype
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for fused attention"""
from typing import Optional
import math
import os
from enum import Enum
from math import sqrt
import jax
import jax.numpy as jnp
......@@ -14,8 +16,6 @@ from flax.linen import combine_masks
from flax.linen import dot_product_attention
from flax.linen import make_attention_mask
from flax.linen import make_causal_mask
from jax import lax
from jax import nn as jax_nn
from jax import value_and_grad, jit
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType
......@@ -25,19 +25,45 @@ from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn
# Type annotations
Array = jnp.ndarray
SELF_CASES = [(32, 512, 16, 64), (32, 128, 16, 64)]
class Backend(Enum):
"""
Fused attn backend.
Unit tests only; the transformer will auto-dispatch to the best backend
"""
Max512 = "0"
Arbitrary = "1"
@pytest.fixture(name="backend", params=[Backend.Max512, Backend.Arbitrary])
def fixture_backend(request):
"""
Fixture for setting up/tearing down the backend
"""
backend = request.param
os.environ["NVTE_FUSED_ATTN_BACKEND"] = backend.value
yield backend
os.environ["NVTE_FUSED_ATTN_BACKEND"] = ""
SELF_CASES = [(32, 512, 16, 64), (32, 128, 16, 64), (4, 2048, 12, 64)]
CROSS_CASES = [(32, 128, 512, 16, 64)]
DTYPES = [jnp.bfloat16, jnp.float16]
PAD_RATIO = [0.3]
def make_decoder_mask(tokens: Array) -> Array:
"""
Create padded causal mask
"""
causal_mask = make_causal_mask(tokens)
padding_mask = make_attention_mask(tokens > 0, tokens > 0)
return combine_masks(causal_mask, padding_mask)
def jax_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
def jax_self_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
"""
Self attention with JAX native implementation
"""
attn_mask_type = kwargs['attn_mask_type']
if attn_mask_type == AttnMaskType.CAUSAL_MASK:
mask = make_decoder_mask(q_token)
......@@ -61,7 +87,10 @@ def jax_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
return output
def jax_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
"""
Cross attention with JAX native implementation
"""
assert q.dtype == kv.dtype
attn_mask_type = kwargs['attn_mask_type']
......@@ -87,6 +116,9 @@ def jax_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
def customcall_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
"""
Self fused attention
"""
if kwargs['attn_mask_type'] == AttnMaskType.CAUSAL_MASK:
mask = make_decoder_mask(q_token)
else:
......@@ -99,6 +131,9 @@ def customcall_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwar
def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
"""
Cross fused attention
"""
assert q.dtype == kv.dtype
if kwargs['attn_mask_type'] == AttnMaskType.CAUSAL_MASK:
......@@ -113,10 +148,33 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs)
@pytest.mark.skipif(not is_fused_attn_kernel_available(),
reason="Fused attention kernel is not supported.")
class TestSelfFusedAttnMax512():
def set_input(self, b, s, h, d, *, attn_bias_type, attn_mask_type, dropout_probability, dtype,
is_training, pad_ratio):
@pytest.mark.parametrize('b, s, h, d', SELF_CASES)
@pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS])
@pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
@pytest.mark.parametrize('dropout_probability', [0., 0.1])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True, False])
@pytest.mark.parametrize('pad_ratio', [0, 0.3])
class TestSelfFusedAttn():
"""Tests for transformer_engine.jax.fused_attn.self_fused_attn"""
@staticmethod
def _check_inputs(s, *, attn_bias_type, attn_mask_type, backend, pad_ratio):
# The arbitrary-seqlen backend has a limited spec for now:
# no bias, causal mask only, and no variable sequence lengths
if (s > 512 or backend == Backend.Arbitrary) and (attn_bias_type != AttnBiasType.NO_BIAS or
attn_mask_type != AttnMaskType.CAUSAL_MASK
or pad_ratio != 0):
pytest.skip("Unsupported inputs combination.")
def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend,
dropout_probability, dtype, is_training, pad_ratio):
"""Setup the test inputs"""
self.__class__._check_inputs(s,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
backend=backend,
pad_ratio=pad_ratio)
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
......@@ -137,78 +195,25 @@ class TestSelfFusedAttnMax512():
axis=-1)
self.kv_token = self.q_token
self.scaling_factor = 1. / math.sqrt(d)
self.scaling_factor = 1. / sqrt(d)
self.dropout_probability = dropout_probability
self.dropout_rng = jax.random.PRNGKey(0) if self.dropout_probability > 0 else None
self.attn_bias_type = attn_bias_type
self.attn_mask_type = attn_mask_type
self.is_training = is_training
@pytest.mark.parametrize('b, s, h, d', SELF_CASES)
@pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS])
@pytest.mark.parametrize('attn_mask_type',
[AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
@pytest.mark.parametrize('dropout_probability', [0., 0.1])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True, False])
@pytest.mark.parametrize('pad_ratio', PAD_RATIO)
def test_sanity(self, b, s, h, d, attn_bias_type, attn_mask_type, dropout_probability, dtype,
is_training, pad_ratio):
def grad_func(func, *args, **kwargs):
# Keep only valid result for the gradient
# fused_attn_max_512 output has shape (b, s, h, d)
valid_ret, _ = jnp.split(func(*args, **kwargs), (self.valid_len,), axis=1)
return jnp.mean(valid_ret, dtype=jnp.float32).astype(dtype)
self.set_input(b,
s,
h,
d,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
dropout_probability=dropout_probability,
dtype=dtype,
is_training=is_training,
pad_ratio=pad_ratio)
kwargs = {
'attn_bias_type': self.attn_bias_type,
'attn_mask_type': attn_mask_type,
'scaling_factor': self.scaling_factor,
'dropout_probability': self.dropout_probability,
'is_training': self.is_training
}
jitted_primitive = jit(
value_and_grad(
lambda qkv, bias, q_token, kv_token, dropout_rng: grad_func(
customcall_self_fused_attn, qkv, bias, q_token, kv_token, dropout_rng, **kwargs
), (0, 1)))
primitive_out, (primitive_dqkv,
primitive_dbias) = jitted_primitive(self.qkv, self.bias, self.q_token,
self.kv_token, self.dropout_rng)
@pytest.mark.parametrize('b, s, h, d', SELF_CASES)
@pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS])
@pytest.mark.parametrize('attn_mask_type',
[AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
@pytest.mark.parametrize('dropout_probability', [0., 0.1])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True, False])
@pytest.mark.parametrize('pad_ratio', PAD_RATIO)
def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, dropout_probability, dtype,
is_training, pad_ratio):
# dropout can't get the bitmatch result
if is_training and dropout_probability > 0.:
return
self.set_input(b,
def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend, dropout_probability,
dtype, is_training, pad_ratio):
"""
Test forward without using JIT
"""
self._set_inputs(b,
s,
h,
d,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
backend=backend,
dropout_probability=dropout_probability,
dtype=dtype,
is_training=is_training,
......@@ -225,7 +230,7 @@ class TestSelfFusedAttnMax512():
dropout_probability=self.dropout_probability,
is_training=self.is_training)
reference_out = jax_self_fused_attn(self.qkv,
reference_out = jax_self_attn(self.qkv,
self.bias,
self.q_token,
self.kv_token,
......@@ -238,6 +243,10 @@ class TestSelfFusedAttnMax512():
ref_valid, _ = jnp.split(reference_out, (self.valid_len,), axis=1)
pri_valid, pri_invalid = jnp.split(primitive_out, (self.valid_len,), axis=1)
# Dropout results are not bitwise reproducible; skip the elementwise comparison
if is_training and dropout_probability > 0.:
return
np.testing.assert_allclose(jnp.asarray(pri_valid, np.float32),
jnp.asarray(ref_valid, np.float32),
rtol=1e-4,
......@@ -246,38 +255,36 @@ class TestSelfFusedAttnMax512():
np.testing.assert_allclose(jnp.asarray(pri_invalid, jnp.float32),
jnp.zeros_like(pri_invalid, jnp.float32))
@pytest.mark.parametrize('b, s, h, d', SELF_CASES)
@pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS])
@pytest.mark.parametrize('attn_mask_type',
[AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
@pytest.mark.parametrize('dropout_probability', [0.]) # dropout can't get the bitmatch result
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True]) # backward is only used when is_training
@pytest.mark.parametrize('pad_ratio', PAD_RATIO)
def test_forward_backward(self, b, s, h, d, attn_bias_type, attn_mask_type, dropout_probability,
dtype, is_training, pad_ratio):
self.set_input(b,
def test_forward_backward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend,
dropout_probability, dtype, is_training, pad_ratio):
"""
Test forward, backward, and autodiff by jax.value_and_grad
"""
if not is_training:
pytest.skip(f"Backward doesn't support {is_training=}")
self._set_inputs(b,
s,
h,
d,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
backend=backend,
dropout_probability=dropout_probability,
dtype=dtype,
is_training=is_training,
pad_ratio=pad_ratio)
def grad_func(fused_attn_max_512_func, *args, **kwargs):
def grad_func(fused_attn_func, *args, **kwargs):
# Gradient is small; use a gradient multiplier to amplify the gradient
gradient_multiplier = 1000 if dtype == jnp.bfloat16 else 10000
if attn_mask_type == AttnMaskType.CAUSAL_MASK:
gradient_multiplier = gradient_multiplier / 10
# Keep only valid result for the gradient
# fused_attn_max_512 output has shape (b, s, h, d)
valid_fused_attn_max_512_ret, _ = jnp.split(fused_attn_max_512_func(*args, **kwargs),
(self.valid_len,),
# fused_attn output has shape (b, s, h, d)
valid_fused_attn_ret, _ = jnp.split(fused_attn_func(*args, **kwargs), (self.valid_len,),
axis=1)
return (jnp.mean(valid_fused_attn_max_512_ret, dtype=jnp.float32) *
return (jnp.mean(valid_fused_attn_ret, dtype=jnp.float32) *
gradient_multiplier).astype(dtype)
kwargs = {
......@@ -298,8 +305,7 @@ class TestSelfFusedAttnMax512():
jitted_reference = jit(
value_and_grad(
lambda qkv, bias, q_token, kv_token, dropout_rng: grad_func(
jax_self_fused_attn, qkv, bias, q_token, kv_token, dropout_rng, **kwargs),
(0, 1)))
jax_self_attn, qkv, bias, q_token, kv_token, dropout_rng, **kwargs), (0, 1)))
primitive_out, (primitive_dqkv,
primitive_dbias) = jitted_primitive(self.qkv, self.bias, self.q_token,
......@@ -309,6 +315,10 @@ class TestSelfFusedAttnMax512():
reference_dbias) = jitted_reference(self.qkv, self.bias, self.q_token,
self.kv_token, self.dropout_rng)
# Dropout results are not bitwise reproducible; skip the elementwise comparison
if dropout_probability > 0.:
return
np.testing.assert_allclose(jnp.asarray(primitive_out, np.float32),
jnp.asarray(reference_out, np.float32),
rtol=1e-4,
......@@ -319,23 +329,14 @@ class TestSelfFusedAttnMax512():
valid_reference_dqkv, invalid_reference_dqkv = jnp.split(reference_dqkv, (self.valid_len,),
axis=1)
# dQ
np.testing.assert_allclose(jnp.asarray(valid_primitive_dqkv[:, :, 0], np.float32),
jnp.asarray(valid_reference_dqkv[:, :, 0], np.float32),
rtol=1e-4,
atol=1e-5)
valid_primitive_dq, valid_primitive_dk, valid_primitive_dv = jnp.split(
valid_primitive_dqkv.astype(jnp.float32), 3, axis=2)
valid_reference_dq, valid_reference_dk, valid_reference_dv = jnp.split(
valid_reference_dqkv.astype(jnp.float32), 3, axis=2)
# dK
np.testing.assert_allclose(jnp.asarray(valid_primitive_dqkv[:, :, 1], np.float32),
jnp.asarray(valid_reference_dqkv[:, :, 1], np.float32),
rtol=1e-4,
atol=1e-5)
# dV
np.testing.assert_allclose(jnp.asarray(valid_primitive_dqkv[:, :, 2], np.float32),
jnp.asarray(valid_reference_dqkv[:, :, 2], np.float32),
rtol=1e-4,
atol=1e-5)
np.testing.assert_allclose(valid_primitive_dq, valid_reference_dq, rtol=1e-4, atol=1e-5)
np.testing.assert_allclose(valid_primitive_dk, valid_reference_dk, rtol=1e-4, atol=1e-5)
np.testing.assert_allclose(valid_primitive_dv, valid_reference_dv, rtol=1e-4, atol=1e-5)
assert jnp.allclose(invalid_primitive_dqkv, invalid_reference_dqkv)
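A small note on the jnp.split used above (a sketch with illustrative shapes, not taken from this test): splitting the packed dQKV gradient along the packing axis keeps a singleton axis on each chunk, unlike the earlier per-index slicing.

    import jax.numpy as jnp

    # Packed dQKV gradient laid out as (batch, seq, 3, heads, dim)
    dqkv = jnp.zeros((2, 128, 3, 16, 64))
    dq, dk, dv = jnp.split(dqkv, 3, axis=2)
    print(dq.shape)   # (2, 128, 1, 16, 64); dqkv[:, :, 0] would give (2, 128, 16, 64)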
......@@ -362,9 +363,16 @@ class TestSelfFusedAttnMax512():
@pytest.mark.skipif(not is_fused_attn_kernel_available(),
reason="Fused attention kernel is not supported.")
class TestCrossFusedAttnMax512():
def set_input(self, b, s_q, s_kv, h, d, *, attn_mask_type, dropout_probability, dtype,
@pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_CASES)
@pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK])
@pytest.mark.parametrize('dropout_probability', [0., 0.1])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True, False])
@pytest.mark.parametrize('pad_ratio', [0.3])
class TestCrossFusedAttn():
"""Tests for transformer_engine.jax.fused_attn.cross_fused_attn"""
def _set_inputs(self, b, s_q, s_kv, h, d, *, attn_mask_type, dropout_probability, dtype,
is_training, pad_ratio):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
......@@ -385,25 +393,19 @@ class TestCrossFusedAttnMax512():
self.kv_token = jnp.concatenate((jnp.ones((b, self.kv_valid_len)), jnp.zeros(
(b, kv_pad_len))),
axis=-1)
self.scaling_factor = 1. / math.sqrt(d)
self.scaling_factor = 1. / sqrt(d)
self.dropout_probability = dropout_probability
self.dropout_rng = jax.random.PRNGKey(0) if self.dropout_probability > 0 else None
self.attn_bias_type = AttnBiasType.NO_BIAS
self.attn_mask_type = attn_mask_type
self.is_training = is_training
@pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_CASES)
@pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK])
@pytest.mark.parametrize('dropout_probability', [0., 0.1])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True, False])
@pytest.mark.parametrize('pad_ratio', PAD_RATIO)
def test_forward(self, b, s_q, s_kv, h, d, attn_mask_type, dropout_probability, dtype,
is_training, pad_ratio):
# dropout can't get the bitmatch result
if is_training and dropout_probability > 0.:
return
self.set_input(b,
"""
Test forward without using JIT
"""
self._set_inputs(b,
s_q,
s_kv,
h,
......@@ -425,7 +427,7 @@ class TestCrossFusedAttnMax512():
dropout_probability=self.dropout_probability,
is_training=self.is_training)
reference_out = jax_cross_fused_attn(self.q,
reference_out = jax_cross_attn(self.q,
self.kv,
self.q_token,
self.kv_token,
......@@ -435,6 +437,10 @@ class TestCrossFusedAttnMax512():
dropout_probability=self.dropout_probability,
is_training=self.is_training)
# Dropout results are not bitwise reproducible; skip the elementwise comparison
if is_training and dropout_probability > 0.:
return
ref_valid, _ = jnp.split(reference_out, (self.q_valid_len,), axis=1)
pri_valid, pri_invalid = jnp.split(primitive_out, (self.q_valid_len,), axis=1)
......@@ -446,15 +452,15 @@ class TestCrossFusedAttnMax512():
np.testing.assert_allclose(jnp.asarray(pri_invalid, jnp.float32),
jnp.zeros_like(pri_invalid, jnp.float32))
@pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_CASES)
@pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK])
@pytest.mark.parametrize('dropout_probability', [0.]) # dropout can't get the bitmatch result
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True]) # backward is only used when is_training
@pytest.mark.parametrize('pad_ratio', PAD_RATIO)
def test_forward_backward(self, b, s_q, s_kv, h, d, attn_mask_type, dropout_probability, dtype,
is_training, pad_ratio):
self.set_input(b,
"""
Test forward, backward, and autodiff by jax.value_and_grad
"""
if not is_training:
pytest.skip(f"Backward doesn't support {is_training=}")
self._set_inputs(b,
s_q,
s_kv,
h,
......@@ -465,17 +471,17 @@ class TestCrossFusedAttnMax512():
is_training=is_training,
pad_ratio=pad_ratio)
def grad_func(fused_attn_max_512_func, *args, **kwargs):
def grad_func(fused_attn_func, *args, **kwargs):
# Gradient is small; use a gradient multiplier to amplify the gradient
gradient_multiplier = 10000
if attn_mask_type == AttnMaskType.CAUSAL_MASK:
gradient_multiplier = gradient_multiplier / 10
# Keep only valid result for the gradient
# fused_attn_max_512 output has shape (b, s_q, h, d)
valid_fused_attn_max_512_ret, _ = jnp.split(fused_attn_max_512_func(*args, **kwargs),
# fused_attn output has shape (b, s_q, h, d)
valid_fused_attn_ret, _ = jnp.split(fused_attn_func(*args, **kwargs),
(self.q_valid_len,),
axis=1)
return (jnp.mean(valid_fused_attn_max_512_ret, dtype=jnp.float32) *
return (jnp.mean(valid_fused_attn_ret, dtype=jnp.float32) *
gradient_multiplier).astype(dtype)
kwargs = {
......@@ -496,7 +502,7 @@ class TestCrossFusedAttnMax512():
jitted_reference = jit(
value_and_grad(
lambda q, kv, q_token, kv_token, dropout_rng: grad_func(
jax_cross_fused_attn, q, kv, q_token, kv_token, dropout_rng, **kwargs), (0, 1)))
jax_cross_attn, q, kv, q_token, kv_token, dropout_rng, **kwargs), (0, 1)))
primitive_out, (primitive_dq,
primitive_dkv) = jitted_primitive(self.q, self.kv, self.q_token,
......@@ -506,6 +512,10 @@ class TestCrossFusedAttnMax512():
reference_dkv) = jitted_reference(self.q, self.kv, self.q_token,
self.kv_token, self.dropout_rng)
# Dropout results are not bitwise reproducible; skip the elementwise comparison
if dropout_probability > 0.:
return
np.testing.assert_allclose(jnp.asarray(primitive_out, np.float32),
jnp.asarray(reference_out, np.float32),
rtol=1e-4,
......
......@@ -547,7 +547,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
if (!is_training) {
dropout_probability == 0.0f;
dropout_probability = 0.0f;
}
FADescriptor descriptor{b, h,
......@@ -1144,7 +1144,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
void *devPtrSoftmaxSum = static_cast<int8_t *>(workspace) + plan_workspace_size;
void *devPtrdQAccumulator = static_cast<int8_t *>(devPtrSoftmaxSum)
+ softmaxSum_workspace_size;
NVTE_CHECK_CUDA(cudaMemset(devPtrdQAccumulator, 0, dqAccum_workspace_size));
NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdQAccumulator, 0, dqAccum_workspace_size, stream));
std::set<std::pair<uint64_t, void *>> data_ptrs;
// add all the data pointers to be used in the variant pack
......@@ -1224,6 +1224,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
devPtrS = output_S->data.dptr;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
void* devPtrDropoutSeed = rng_state->data.dptr;
......@@ -1250,6 +1252,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
......@@ -1312,6 +1316,8 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen,
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
} // namespace transformer_engine
......
......@@ -1275,6 +1275,8 @@ void fused_attn_max_512_fwd_qkvpacked(
} else if (Aux_CTX_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
void *devPtrCuSeqlen = cu_seqlens->data.dptr;
......@@ -1304,6 +1306,8 @@ void fused_attn_max_512_fwd_qkvpacked(
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
......@@ -1351,6 +1355,8 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
} else if (Aux_CTX_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
void *devQCuSeqlen = q_cu_seqlens->data.dptr;
......@@ -1380,6 +1386,8 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
......@@ -1440,6 +1448,8 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
......@@ -1503,6 +1513,8 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
}
} // namespace transformer_engine
......
......@@ -15,7 +15,7 @@ from jaxlib.hlo_helpers import custom_call
import jax.numpy as jnp
from jax.lib import xla_client
from jax import core, dtypes
from jax.abstract_arrays import ShapedArray
from jax.core import ShapedArray
from jax.interpreters import xla, mlir
from jax.interpreters.mlir import ir, dtype_to_ir_type
......@@ -23,6 +23,8 @@ import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_Fused_Attn_Backend
for _name, _value in transformer_engine_jax.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
......@@ -1981,32 +1983,50 @@ def scaled_upper_triang_masked_softmax_bwd(grad_outputs: jnp.ndarray, softmax_ou
scale_factor=scale_factor)
def _check_seed(seed, dropout_probability, is_training):
@dataclass(frozen=True)
class _FusedAttnRNGStateChecker:
"""
Checker for guarding the fused attention rng state.
The fused attention backend requires a 64-bit seed and a 64-bit offset.
However, JAX doesn't enable 64-bit types by default,
so we emulate the seed as an array of two 32-bit values.
The offset calculation is maintained in the backend.
"""
rng_state_dtype: jnp.dtype = jnp.uint32
# (seed,) with internal dtype int64
seed_size: int = 2
# (seed, offset) with internal dtype int64
rng_state_size: int = 2 * 2
def check_seed(self, seed, dropout_probability, is_training):
"""
Check the seed and convert the data type of seed if possible.
"""
# JAX can't bind None; create a dummy tensor for None
if seed is None:
dropout_enabled = dropout_probability > 0 and is_training
assert not dropout_enabled, "seed is not allowed to be None when dropout is enabled."
seed = jnp.zeros(2, dtype=jnp.uint32)
seed = jnp.zeros(2, dtype=self.rng_state_dtype)
if seed.dtype != jnp.uint32:
if seed.dtype != self.rng_state_dtype:
warnings.warn(
f"Requested {seed.dtype=} is not available, and will be "
f"casted to dtype uint32. "
f"casted to dtype {self.rng_state_dtype}. "
f"Please use threefry/rbg/unsafe_rbg PRNG implementations to remove this warning.")
seed = seed.astype(jnp.uint32)
seed = seed.astype(self.rng_state_dtype)
assert seed.dtype == jnp.uint32
# Only the first 2 u32 elements are taken
assert seed.size >= 2
assert seed.dtype == self.rng_state_dtype
# Backend takes an int64_t seed, so only the first two u32 elements are taken
assert seed.size >= self.seed_size
return seed
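A brief aside on the seed convention enforced by check_seed above (a sketch, not part of this diff): with JAX's default threefry PRNG, a key is already a shape-(2,) uint32 array, i.e. a 64-bit seed emulated as two 32-bit words, and when dropout is disabled the caller may pass None.

    import jax
    import jax.numpy as jnp

    # Default threefry PRNG: a key is two uint32 words, which is what
    # check_seed expects (the backend reassembles them into an int64_t seed).
    seed = jax.random.PRNGKey(1234)
    print(seed.dtype, seed.shape)   # uint32 (2,)

    # With dropout disabled, None is accepted; the checker substitutes a dummy
    # all-zero seed of the same shape and dtype.
    dummy = jnp.zeros(2, dtype=jnp.uint32)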
class SelfFusedAttnMax512FwdPrimitive(BasePrimitive):
class SelfFusedAttnFwdPrimitive(BasePrimitive):
"""
Self Fused Attention Max Seqlen 512 Forward Primitive
Self Fused Attention Forward Primitive
"""
name = "te_self_fused_attn_max_512_forward"
name = "te_self_fused_attn_forward"
multiple_results = True
@staticmethod
......@@ -2023,7 +2043,7 @@ class SelfFusedAttnMax512FwdPrimitive(BasePrimitive):
is_training # pylint: disable=unused-argument
):
"""
Self fused attention max seqlen 512 fwd abstract
Self fused attention fwd abstract
"""
qkv_dtype = dtypes.canonicalize_dtype(qkv.dtype)
batch, max_seqlen, nqkv, num_head, head_dim = qkv.shape
......@@ -2033,78 +2053,79 @@ class SelfFusedAttnMax512FwdPrimitive(BasePrimitive):
output_shape = (batch, max_seqlen, num_head, head_dim)
output_dtype = qkv_dtype
backend = transformer_engine_jax.get_fused_attn_backend(
jax_dtype_to_te_dtype(qkv_dtype), jax_dtype_to_te_dtype(qkv_dtype),
NVTE_QKV_Layout.NVTE_QKV_INTERLEAVED, attn_bias_type, attn_mask_type,
dropout_probability, max_seqlen, max_seqlen, head_dim)
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_aux_shape = (batch, num_head, max_seqlen, max_seqlen)
softmax_dtype = qkv_dtype
elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
softmax_aux_shape = (batch, num_head, max_seqlen, 1)
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
else:
raise ValueError(f'Not supported {backend=}')
checker = _FusedAttnRNGStateChecker()
seed_dtype = dtypes.canonicalize_dtype(seed.dtype)
assert seed_dtype == checker.rng_state_dtype
rng_state_shape = (checker.rng_state_size,)
rng_state_dtype = seed_dtype
return (
ShapedArray(output_shape, output_dtype, named_shape=qkv.named_shape), # output
ShapedArray(softmax_aux_shape, softmax_dtype,
named_shape=qkv.named_shape), # softmax_aux
ShapedArray(rng_state_shape, rng_state_dtype,
named_shape=seed.named_shape), # rng_state
)
@staticmethod
def lowering(ctx, qkv, bias, cu_seqlen, seed, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
"""
Self fused attention max seqlen 512 fwd lowering rules
Self fused attention fwd lowering rules
"""
qkv_aval, _, _, _ = ctx.avals_in
ir_qkv_type = ir.RankedTensorType(qkv.type)
ir_qkv_shape = ir_qkv_type.shape
ir_bias_type = ir.RankedTensorType(bias.type)
ir_bias_shape = ir_bias_type.shape
ir_cu_seqlen_type = ir.RankedTensorType(cu_seqlen.type)
ir_cu_seqlen_shape = ir_cu_seqlen_type.shape
ir_seed_type = ir.RankedTensorType(seed.type)
ir_seed_shape = ir_seed_type.shape
batch, max_seqlen, nqkv, num_head, head_dim = ir_qkv_shape
assert nqkv == 3
output_shape = (batch, max_seqlen, num_head, head_dim)
softmax_aux_shape = (batch, num_head, max_seqlen, max_seqlen)
batch, max_seqlen, _, num_head, head_dim = qkv_aval.shape
operands = [qkv, bias, cu_seqlen, seed]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output_shape, ir_qkv_type.element_type),
ir.RankedTensorType.get(softmax_aux_shape, ir_qkv_type.element_type)
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
operands = [qkv, bias, cu_seqlen, seed]
operand_shapes = [ir_qkv_shape, ir_bias_shape, ir_cu_seqlen_shape, ir_seed_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability,
attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
out = custom_caller(SelfFusedAttnMax512FwdPrimitive.name,
args,
opaque,
has_side_effect=False)
out = custom_caller(SelfFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
return out
_self_fused_attn_max_512_fwd_p = register_primitive(SelfFusedAttnMax512FwdPrimitive)
_self_fused_attn_fwd_p = register_primitive(SelfFusedAttnFwdPrimitive)
def self_fused_attn_max_512_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, cu_seqlen: jnp.ndarray,
def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, cu_seqlen: jnp.ndarray,
seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool):
"""
Wrapper for TE self fused attention max seqlen 512 fwd
Wrapper for TE self fused attention fwd
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
"""
seed = _check_seed(seed, dropout_probability, is_training)
checker = _FusedAttnRNGStateChecker()
seed = checker.check_seed(seed, dropout_probability, is_training)
if bias is None:
assert attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=qkv.dtype)
return _self_fused_attn_max_512_fwd_p.bind(qkv,
return _self_fused_attn_fwd_p.bind(qkv,
bias,
cu_seqlen,
seed,
......@@ -2115,17 +2136,19 @@ def self_fused_attn_max_512_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, cu_seqlen:
is_training=is_training)
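For readers unfamiliar with the pipeline named in the wrapper docstring (BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2), here is a minimal unfused JAX sketch of the same computation for the packed-QKV case; the layout, post-scale bias placement, and mask convention are illustrative assumptions, not taken from this diff.

    import jax
    import jax.numpy as jnp

    def unfused_self_attn(qkv, bias, mask, scaling_factor, dropout_rng=None, dropout_p=0.0):
        # qkv packed as (batch, seqlen, 3, heads, dim); mask is True where attention is kept
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        logits = jnp.einsum('bqhd,bkhd->bhqk', q, k) * scaling_factor   # BMM1 + scale
        logits = logits + bias                                          # post-scale bias
        logits = jnp.where(mask, logits, -1e10)                         # mask
        probs = jax.nn.softmax(logits, axis=-1)                         # softmax
        if dropout_rng is not None and dropout_p > 0.0:                 # dropout
            keep = jax.random.bernoulli(dropout_rng, 1.0 - dropout_p, probs.shape)
            probs = probs * keep / (1.0 - dropout_p)
        return jnp.einsum('bhqk,bkhd->bqhd', probs, v)                  # BMM2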
class SelfFusedAttnMax512BwdPrimitive(BasePrimitive):
class SelfFusedAttnBwdPrimitive(BasePrimitive):
"""
Self Fused Attention Max Seqlen 512 Backward Primitive
Self Fused Attention Backward Primitive
"""
name = "te_self_fused_attn_max_512_backward"
name = "te_self_fused_attn_backward"
multiple_results = True
@staticmethod
def abstract(
qkv,
softmax_aux,
softmax_aux, # pylint: disable=unused-argument
rng_state, # pylint: disable=unused-argument
output, # pylint: disable=unused-argument
doutput,
cu_seqlen, # pylint: disable=unused-argument
*,
......@@ -2139,10 +2162,13 @@ class SelfFusedAttnMax512BwdPrimitive(BasePrimitive):
Self fused attention bwd abstract
"""
qkv_dtype = dtypes.canonicalize_dtype(qkv.dtype)
assert qkv.dtype == softmax_aux.dtype == doutput.dtype
assert qkv.dtype == doutput.dtype
_, seqlen, _, num_head, _ = qkv.shape
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_shape = (0,)
else:
bias_shape = (1, num_head, seqlen, seqlen)
bias_dtype = qkv_dtype
......@@ -2151,66 +2177,48 @@ class SelfFusedAttnMax512BwdPrimitive(BasePrimitive):
ShapedArray(bias_shape, bias_dtype, named_shape=qkv.named_shape))
@staticmethod
def lowering(ctx, qkv, softmax_aux, doutput, cu_seqlen, *, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
def lowering(ctx, qkv, softmax_aux, rng_state, output, doutput, cu_seqlen, *, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training):
"""
Self fused attention max seqlen 512 bwd lowering rules
Self fused attention bwd lowering rules
"""
qkv_aval, _, _, _ = ctx.avals_in
ir_qkv_type = ir.RankedTensorType(qkv.type)
ir_qkv_shape = ir_qkv_type.shape
ir_softmax_aux_type = ir.RankedTensorType(softmax_aux.type)
ir_softmax_aux_shape = ir_softmax_aux_type.shape
ir_doutput_type = ir.RankedTensorType(doutput.type)
ir_doutput_shape = ir_doutput_type.shape
qkv_aval, _, _, _, _, _ = ctx.avals_in
ir_cu_seqlen_type = ir.RankedTensorType(cu_seqlen.type)
ir_cu_seqlen_shape = ir_cu_seqlen_type.shape
batch, max_seqlen, num_head, head_dim = ir_doutput_shape
dbias_shape = (1, num_head, max_seqlen, max_seqlen)
dbias_dtype = ir_qkv_type.element_type
batch, max_seqlen, _, num_head, head_dim = qkv_aval.shape
operands = [qkv, softmax_aux, rng_state, output, doutput, cu_seqlen]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(ir_qkv_shape, ir_qkv_type.element_type),
ir.RankedTensorType.get(dbias_shape, dbias_dtype)
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
operands = [qkv, softmax_aux, doutput, cu_seqlen]
operand_shapes = [ir_qkv_shape, ir_softmax_aux_shape, ir_doutput_shape, ir_cu_seqlen_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
# the dropout elements are encoded in the forward auxiliary tensor
# so seed is not needed in backward
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability,
attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
out = custom_caller(SelfFusedAttnMax512BwdPrimitive.name,
args,
opaque,
has_side_effect=False)
out = custom_caller(SelfFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
return out
_self_fused_attn_max_512_bwd_p = register_primitive(SelfFusedAttnMax512BwdPrimitive)
_self_fused_attn_bwd_p = register_primitive(SelfFusedAttnBwdPrimitive)
def self_fused_attn_max_512_bwd(qkv: jnp.ndarray, softmax_aux: jnp.ndarray, doutput: jnp.ndarray,
cu_seqlen: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool):
def self_fused_attn_bwd(qkv: jnp.ndarray, softmax_aux: jnp.ndarray, rng_state: jnp.ndarray,
output: jnp.ndarray, doutput: jnp.ndarray, cu_seqlen: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
scaling_factor: float, dropout_probability: float, is_training: bool):
"""
Wrapper for TE self fused attention max seqlen 512 bwd
Wrapper for TE self fused attention bwd
Return the gradients of self fused attention with packed qkv input
"""
return _self_fused_attn_max_512_bwd_p.bind(qkv,
return _self_fused_attn_bwd_p.bind(qkv,
softmax_aux,
rng_state,
output,
doutput,
cu_seqlen,
attn_bias_type=attn_bias_type,
......@@ -2220,11 +2228,11 @@ def self_fused_attn_max_512_bwd(qkv: jnp.ndarray, softmax_aux: jnp.ndarray, dout
is_training=is_training)
class CrossFusedAttnMax512FwdPrimitive(BasePrimitive):
class CrossFusedAttnFwdPrimitive(BasePrimitive):
"""
Cross Fused Attention Forward Max Seqlen 512 Primitive
Cross Fused Attention Forward Primitive
"""
name = "te_cross_fused_attn_max_512_forward"
name = "te_cross_fused_attn_forward"
multiple_results = True
@staticmethod
......@@ -2242,7 +2250,7 @@ class CrossFusedAttnMax512FwdPrimitive(BasePrimitive):
is_training # pylint: disable=unused-argument
):
"""
Cross fused attention max seqlen 512 fwd abstract
Cross fused attention fwd abstract
"""
q_dtype = dtypes.canonicalize_dtype(q.dtype)
batch_q, q_max_seqlen, num_head_q, head_dim_q = q.shape
......@@ -2271,36 +2279,19 @@ class CrossFusedAttnMax512FwdPrimitive(BasePrimitive):
def lowering(ctx, q, kv, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
"""
Cross fused attention max seqlen 512 fwd lowering rules
Cross fused attention fwd lowering rules
"""
q_aval, kv_aval, _, _, _ = ctx.avals_in
assert q_aval.dtype == kv_aval.dtype
ir_q_type = ir.RankedTensorType(q.type)
ir_q_shape = ir_q_type.shape
ir_kv_type = ir.RankedTensorType(kv.type)
ir_kv_shape = ir_kv_type.shape
ir_q_cu_seqlen_shape = ir.RankedTensorType(q_cu_seqlen.type).shape
ir_kv_cu_seqlen_shape = ir.RankedTensorType(kv_cu_seqlen.type).shape
batch, q_max_seqlen, num_head, head_dim = q_aval.shape
kv_max_seqlen = kv_aval.shape[1]
ir_seed_type = ir.RankedTensorType(seed.type)
ir_seed_shape = ir_seed_type.shape
batch, q_max_seqlen, num_head, head_dim = ir_q_shape
kv_max_seqlen = ir_kv_shape[1]
output_shape = (batch, q_max_seqlen, num_head, head_dim)
softmax_aux_shape = (batch, num_head, q_max_seqlen, kv_max_seqlen)
out_types = [
ir.RankedTensorType.get(output_shape, ir_q_type.element_type),
ir.RankedTensorType.get(softmax_aux_shape, ir_q_type.element_type)
]
operands = [q, kv, q_cu_seqlen, kv_cu_seqlen, seed]
operand_shapes = [
ir_q_shape, ir_kv_shape, ir_q_cu_seqlen_shape, ir_kv_cu_seqlen_shape, ir_seed_shape
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
......@@ -2309,29 +2300,26 @@ class CrossFusedAttnMax512FwdPrimitive(BasePrimitive):
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training)
out = custom_caller(CrossFusedAttnMax512FwdPrimitive.name,
args,
opaque,
has_side_effect=False)
out = custom_caller(CrossFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
return out
_cross_fused_attn_max_512_fwd_p = register_primitive(CrossFusedAttnMax512FwdPrimitive)
_cross_fused_attn_fwd_p = register_primitive(CrossFusedAttnFwdPrimitive)
def cross_fused_attn_max_512_fwd(q: jnp.ndarray, kv: jnp.ndarray, q_cu_seqlen: jnp.ndarray,
def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, q_cu_seqlen: jnp.ndarray,
kv_cu_seqlen: jnp.ndarray, seed: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
scaling_factor: float, dropout_probability: float,
is_training: bool):
scaling_factor: float, dropout_probability: float, is_training: bool):
"""
Wrapper for TE cross fused attention max seqlen 512 fwd
Wrapper for TE cross fused attention fwd
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
"""
seed = _check_seed(seed, dropout_probability, is_training)
checker = _FusedAttnRNGStateChecker()
seed = checker.check_seed(seed, dropout_probability, is_training)
return _cross_fused_attn_max_512_fwd_p.bind(q,
return _cross_fused_attn_fwd_p.bind(q,
kv,
q_cu_seqlen,
kv_cu_seqlen,
......@@ -2343,11 +2331,11 @@ def cross_fused_attn_max_512_fwd(q: jnp.ndarray, kv: jnp.ndarray, q_cu_seqlen: j
is_training=is_training)
class CrossFusedAttnMax512BwdPrimitive(BasePrimitive):
class CrossFusedAttnBwdPrimitive(BasePrimitive):
"""
Cross Fused Attention Max Seqlen 512 Backward Primitive
Cross Fused Attention Backward Primitive
"""
name = "te_cross_fused_attn_max_512_backward"
name = "te_cross_fused_attn_backward"
multiple_results = True
@staticmethod
......@@ -2366,7 +2354,7 @@ class CrossFusedAttnMax512BwdPrimitive(BasePrimitive):
is_training # pylint: disable=unused-argument
):
"""
Cross fused attention max seqlen 512 bwd abstract
Cross fused attention bwd abstract
"""
q_dtype = dtypes.canonicalize_dtype(q.dtype)
kv_dtype = dtypes.canonicalize_dtype(kv.dtype)
......@@ -2384,34 +2372,19 @@ class CrossFusedAttnMax512BwdPrimitive(BasePrimitive):
def lowering(ctx, q, kv, softmax_aux, doutput, q_cu_seqlen, kv_cu_seqlen, *, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training):
"""
Cross fused attention max seqlen 512 bwd lowering rules
Cross fused attention bwd lowering rules
"""
q_aval, _, _, _, _, _ = ctx.avals_in
ir_q_type = ir.RankedTensorType(q.type)
ir_q_shape = ir_q_type.shape
ir_kv_type = ir.RankedTensorType(kv.type)
ir_kv_shape = ir_kv_type.shape
ir_softmax_aux_type = ir.RankedTensorType(softmax_aux.type)
ir_softmax_aux_shape = ir_softmax_aux_type.shape
ir_doutput_shape = ir.RankedTensorType(doutput.type).shape
ir_q_cu_seqlen_shape = ir.RankedTensorType(q_cu_seqlen.type).shape
ir_kv_cu_seqlen_shape = ir.RankedTensorType(kv_cu_seqlen.type).shape
q_aval, kv_aval, _, _, _, _ = ctx.avals_in
assert q_aval.dtype == kv_aval.dtype
batch, q_max_seqlen, num_head, head_dim = ir_doutput_shape
kv_max_seqlen = ir_kv_shape[1]
batch, q_max_seqlen, num_head, head_dim = q_aval.shape
kv_max_seqlen = kv_aval.shape[1]
out_types = [
ir.RankedTensorType.get(ir_q_shape, ir_q_type.element_type),
ir.RankedTensorType.get(ir_kv_shape, ir_kv_type.element_type),
]
operands = [q, kv, softmax_aux, doutput, q_cu_seqlen, kv_cu_seqlen]
operand_shapes = [
ir_q_shape, ir_kv_shape, ir_softmax_aux_shape, ir_doutput_shape, ir_q_cu_seqlen_shape,
ir_kv_cu_seqlen_shape
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
......@@ -2423,27 +2396,23 @@ class CrossFusedAttnMax512BwdPrimitive(BasePrimitive):
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training)
out = custom_caller(CrossFusedAttnMax512BwdPrimitive.name,
args,
opaque,
has_side_effect=False)
out = custom_caller(CrossFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
return out
_cross_fused_attn_max_512_bwd_p = register_primitive(CrossFusedAttnMax512BwdPrimitive)
_cross_fused_attn_bwd_p = register_primitive(CrossFusedAttnBwdPrimitive)
def cross_fused_attn_max_512_bwd(q: jnp.ndarray, kv: jnp.ndarray, softmax_aux: jnp.ndarray,
doutput: jnp.ndarray, q_cu_seqlen: jnp.ndarray,
kv_cu_seqlen: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool):
def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, softmax_aux: jnp.ndarray,
doutput: jnp.ndarray, q_cu_seqlen: jnp.ndarray, kv_cu_seqlen: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
scaling_factor: float, dropout_probability: float, is_training: bool):
"""
Wrapper for TE cross fused attention max seqlen 512 bwd
Wrapper for TE cross fused attention bwd
Return the gradients of cross fused attention with packed kv input
"""
return _cross_fused_attn_max_512_bwd_p.bind(q,
return _cross_fused_attn_bwd_p.bind(q,
kv,
softmax_aux,
doutput,
......
......@@ -46,11 +46,10 @@ pybind11::dict Registrations() {
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForward);
dict["te_scaled_upper_triang_masked_softmax_backward"] =
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward);
dict["te_self_fused_attn_max_512_forward"] = EncapsulateFunction(SelfFusedAttnMax512Forward);
dict["te_self_fused_attn_max_512_backward"] = EncapsulateFunction(SelfFusedAttnMax512Backward);
dict["te_cross_fused_attn_max_512_forward"] = EncapsulateFunction(CrossFusedAttnMax512Forward);
dict["te_cross_fused_attn_max_512_backward"] =
EncapsulateFunction(CrossFusedAttnMax512Backward);
dict["te_self_fused_attn_forward"] = EncapsulateFunction(SelfFusedAttnForward);
dict["te_self_fused_attn_backward"] = EncapsulateFunction(SelfFusedAttnBackward);
dict["te_cross_fused_attn_forward"] = EncapsulateFunction(CrossFusedAttnForward);
dict["te_cross_fused_attn_backward"] = EncapsulateFunction(CrossFusedAttnBackward);
return dict;
}
......@@ -65,6 +64,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_device_compute_capability", &GetDeviceComputeCapability);
m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor);
m.def("is_fused_attn_kernel_available", &IsFusedAttnKernelAvailable);
m.def("get_fused_attn_backend", &GetFusedAttnBackend);
pybind11::enum_<DType>(m, "DType", pybind11::module_local())
.value("kByte", DType::kByte)
......@@ -85,6 +85,17 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
.value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
.value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK);
pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local())
.value("NVTE_NOT_INTERLEAVED", NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED)
.value("NVTE_QKV_INTERLEAVED", NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
.value("NVTE_KV_INTERLEAVED", NVTE_QKV_Layout::NVTE_KV_INTERLEAVED);
pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local())
.value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend)
.value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)
.value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)
.value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8);
}
} // namespace jax
......
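The bindings above now expose get_fused_attn_backend together with the NVTE_QKV_Layout and NVTE_Fused_Attn_Backend enums. A minimal sketch of how the new query can be exercised from Python; the argument values are illustrative, and the jax_dtype_to_te_dtype helper is assumed to be importable from transformer_engine.jax.cpp_extensions, where it is used elsewhere in this diff.

    import jax.numpy as jnp
    import transformer_engine_jax as tex
    from transformer_engine.jax.cpp_extensions import jax_dtype_to_te_dtype

    # Ask the library which fused-attention backend it would pick for a given
    # problem size; mirrors the call added to the forward abstract rule above.
    backend = tex.get_fused_attn_backend(
        jax_dtype_to_te_dtype(jnp.bfloat16), jax_dtype_to_te_dtype(jnp.bfloat16),
        tex.NVTE_QKV_Layout.NVTE_QKV_INTERLEAVED,
        tex.NVTE_Bias_Type.NVTE_NO_BIAS,
        tex.NVTE_Mask_Type.NVTE_CAUSAL_MASK,
        0.0,           # dropout probability
        2048, 2048,    # q/kv max sequence lengths
        64)            # head dim
    print(backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen)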
......@@ -740,7 +740,18 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers,
desc.scale_factor, stream);
}
void SelfFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char *opaque,
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim) {
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_max_seqlen, kv_max_seqlen, head_dim);
return backend;
}
void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
......@@ -754,12 +765,17 @@ void SelfFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
// output
void *output = buffers[4];
void *softmax_aux = buffers[5];
void *rng_state = buffers[6];
auto batch = descriptor.batch;
auto num_head = descriptor.num_head;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED;
NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
"q_max_seqlen should be equal to kv_max_seqlen in the self attention.");
......@@ -768,58 +784,56 @@ void SelfFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
// input tensors
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
auto cu_seqlens_tensor =
TensorWrapper(cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
// FP16/BF16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
// output tensors
auto o_tensor =
TensorWrapper(output, std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}, dtype);
auto cu_seqlens_tensor =
TensorWrapper(cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
// F16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
// aux tensors
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_max_seqlen, kv_max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
TensorWrapper query_workspace_tensor;
nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, cu_seqlens_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
descriptor.is_training, descriptor.scaling_factor, descriptor.dropout_probability,
NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type,
query_workspace_tensor.data(), stream);
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, descriptor.is_training,
descriptor.scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, query_workspace_tensor.data(), stream);
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
output_s->data.dptr = softmax_aux;
// fused attn workspace + workspace for rng_state
auto plan_workspace_size =
query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
auto rng_workspace_size = 2 * sizeof(int64_t);
auto total_workspace_size = plan_workspace_size + rng_workspace_size;
auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(total_workspace_size);
auto workspace_size = query_workspace_tensor.shape().data[0];
auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size);
auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
auto rng_state = static_cast<uint8_t *>(workspace) + plan_workspace_size;
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, stream);
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, descriptor.is_training,
descriptor.scaling_factor, descriptor.dropout_probability,
NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type,
descriptor.mask_type, workspace_tensor.data(), stream);
descriptor.scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors);
}
void SelfFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char *opaque,
void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
......@@ -827,19 +841,24 @@ void SelfFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char
// input
void *qkv = buffers[0];
void *softmax_aux = buffers[1];
void *doutput = buffers[2];
void *cu_seqlens = buffers[3];
void *rng_state = buffers[2];
void *output = buffers[3];
void *doutput = buffers[4];
void *cu_seqlens = buffers[5];
// output
void *dqkv = buffers[4];
void *dp = softmax_aux;
void *dbias = buffers[5];
void *dqkv = buffers[6];
void *dbias = buffers[7];
auto batch = descriptor.batch;
auto num_head = descriptor.num_head;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED;
NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
"q_max_seqlen should be equal to kv_max_seqlen in the self attention.");
......@@ -850,11 +869,9 @@ void SelfFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
auto output_tensor = TensorWrapper(output, output_shape, dtype);
auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
// It's a little trick that the flash attn needs fwd output
// But when seqlen <= 512, it is not needed
auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
// FP16/BF16 doesn't use this tensor
// F16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
......@@ -862,49 +879,46 @@ void SelfFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char
auto cu_seqlens_tensor =
TensorWrapper(cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
// Currently, no rng_state required for bwd
auto rng_state = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt64);
// TODO: need to think about how to pass aux_output_tensors
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
aux_output_tensors.size = 1;
aux_output_tensors.size = 2;
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
output_s->data.shape = std::vector<size_t>{batch, num_head, q_max_seqlen, kv_max_seqlen};
output_s->data.dptr = softmax_aux;
auto *rng_state_tensor = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[1]);
rng_state_tensor->data.shape = std::vector<size_t>{2};
rng_state_tensor->data.dtype = DType::kInt64;
rng_state_tensor->data.dptr = rng_state;
TensorWrapper query_workspace_tensor;
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for FP16/BF16
s_tensor.data(), // not used for FP16/BF16
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_output_tensors, dqkv_tensor.data(), dbias_tensor.data(),
cu_seqlens_tensor.data(), q_max_seqlen, descriptor.scaling_factor,
descriptor.dropout_probability,
NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type,
descriptor.mask_type, query_workspace_tensor.data(), stream);
dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), stream);
size_t workspace_size =
query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
size_t workspace_size = query_workspace_tensor.shape().data[0];
auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size);
auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for FP16/BF16
s_tensor.data(), // not used for FP16/BF16
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_output_tensors, dqkv_tensor.data(), dbias_tensor.data(),
cu_seqlens_tensor.data(), q_max_seqlen, descriptor.scaling_factor,
descriptor.dropout_probability,
NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type,
descriptor.mask_type, workspace_tensor.data(), stream);
dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors);
}
void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char *opaque,
void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
......@@ -925,6 +939,10 @@ void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_KV_INTERLEAVED;
auto dtype = descriptor.dtype;
auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
......@@ -958,8 +976,7 @@ void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
descriptor.scaling_factor, descriptor.dropout_probability,
NVTE_QKV_Layout::NVTE_KV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type,
descriptor.scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), stream);
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
......@@ -976,20 +993,23 @@ void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
auto rng_state = static_cast<uint8_t *>(workspace) + plan_workspace_size;
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, stream);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_max_seqlen, kv_max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
descriptor.scaling_factor, descriptor.dropout_probability,
NVTE_QKV_Layout::NVTE_KV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type,
descriptor.scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors);
}
void CrossFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char *opaque,
void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
......
......@@ -116,6 +116,12 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor(
bool IsFusedAttnKernelAvailable();
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim);
void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
......@@ -166,16 +172,16 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers,
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
void SelfFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char *opaque,
void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void SelfFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char *opaque,
void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char *opaque,
void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void CrossFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char *opaque,
void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
} // namespace jax
......
......@@ -41,9 +41,15 @@ __global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t
}
void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
size_t kv_max_seqlen, cudaStream_t stream) {
size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
cudaStream_t stream) {
size_t increment = 0;
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
increment = 16;
} else {
constexpr int threads_per_cta = 128;
const size_t increment = (q_max_seqlen * kv_max_seqlen + threads_per_cta - 1) / threads_per_cta;
increment = (q_max_seqlen * kv_max_seqlen + threads_per_cta - 1) / threads_per_cta;
}
auto offset = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(increment);
populate_rng_state_kernel<<<1, 1, 0, stream>>>(reinterpret_cast<int64_t *>(rng_state_dst),
reinterpret_cast<const int64_t *>(seed), offset);
......
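The change above makes the Philox offset increment backend-dependent: the arbitrary-seqlen (flash attention) backend always advances the RNG offset by a fixed 16, while the max-512 backend advances it by one count per element handled by a 128-thread CTA. A small worked example of the same arithmetic (Python, for illustration only):

# Worked example of the RNG-offset increments chosen in PopulateRngStateAsync.
def rng_offset_increment(q_max_seqlen, kv_max_seqlen, arbitrary_seqlen_backend):
    if arbitrary_seqlen_backend:     # NVTE_F16_arbitrary_seqlen (flash attention)
        return 16                    # fixed increment for that kernel
    threads_per_cta = 128            # max-512 backend: ceil-divide work per thread
    return (q_max_seqlen * kv_max_seqlen + threads_per_cta - 1) // threads_per_cta

assert rng_offset_increment(512, 512, arbitrary_seqlen_backend=False) == 2048
assert rng_offset_increment(2048, 2048, arbitrary_seqlen_backend=True) == 16
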
......@@ -13,6 +13,7 @@
#include <stdexcept>
#include <string>
#include <type_traits>
#include "transformer_engine/fused_attn.h"
#include "transformer_engine/logging.h"
namespace transformer_engine {
......@@ -22,7 +23,8 @@ int GetCudaRuntimeVersion();
int GetDeviceComputeCapability(int gpu_id);
void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
size_t kv_max_seqlen, cudaStream_t stream);
size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
cudaStream_t stream);
class cublasLtMetaManager {
public:
......
......@@ -178,7 +178,8 @@ def core_attention(query: Array,
attn_weights = Softmax(softmax_type=softmax_type,
scale_factor=fused_scale_factor,
sharding_type=softmax_sharding_type)(attn_weights, mask, bias)
sharding_type=softmax_sharding_type)(attn_weights, mask,
bias).astype(dtype)
if not deterministic and dropout_rate > 0.:
keep_prob = 1.0 - dropout_rate
......@@ -369,12 +370,20 @@ class MultiHeadAttention(nn.Module):
canonicalize_dtype = dtypes.canonicalize_dtype(self.dtype)
q_seqlen = inputs_q.shape[0] if self.transpose_batch_sequence else inputs_q.shape[1]
kv_seqlen = inputs_kv.shape[0] if self.transpose_batch_sequence else inputs_kv.shape[1]
fused_attn_supported_seqlen = [128, 256, 384, 512]
enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))
def _check_seqlen(seqlen):
return seqlen % 64 == 0
def _check_head_dim(head_dim):
return head_dim in [64, 128]
use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \
canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \
q_seqlen in fused_attn_supported_seqlen and kv_seqlen in fused_attn_supported_seqlen \
and is_fused_attn_kernel_available() and (self.head_dim == 64) and enable_fused_attn
_check_seqlen(q_seqlen) and _check_seqlen(kv_seqlen) and \
_check_head_dim(self.head_dim) and \
is_fused_attn_kernel_available() and \
enable_fused_attn
if enable_fused_attn and not use_fused_attn:
reason = ""
......@@ -388,16 +397,16 @@ class MultiHeadAttention(nn.Module):
if canonicalize_dtype not in [jnp.bfloat16, jnp.float16]:
reason += f"dtype in [BF16, FP16] is required " \
f"but got dtype={canonicalize_dtype}, "
if q_seqlen not in fused_attn_supported_seqlen:
reason += f"q_seqlen in {fused_attn_supported_seqlen} is required " \
if not _check_seqlen(q_seqlen):
reason += f"q_seqlen % 64 == 0 is required " \
f"but got {q_seqlen=}, "
if kv_seqlen not in fused_attn_supported_seqlen:
reason += f"kv_seqlen in {fused_attn_supported_seqlen} is required " \
if not _check_seqlen(kv_seqlen):
reason += f"kv_seqlen % 64 == 0 is required " \
f"but got {kv_seqlen=}, "
if not _check_head_dim(self.head_dim):
reason += f"head_dim should be 64 or 128 but got {self.head_dim}, "
if not is_fused_attn_kernel_available():
reason += "GPU arch >= Ampere and cuDNN >= 8.9.1 are required, "
if self.head_dim != 64:
reason += f"head_dim should be 64 but got {self.head_dim}, "
warnings.warn(
f"Fused attention is not enabled, " \
......
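The eligibility check above replaces the fixed list of supported sequence lengths with a divisibility rule (seqlen % 64 == 0) and widens head_dim to 64 or 128, so the arbitrary-seqlen flash attention backend can also be selected. A hedged, self-contained restatement of the gating predicate (names mirror the diff; the surrounding module attributes are assumed):

import jax.numpy as jnp

def can_use_fused_attn(dtype, q_seqlen, kv_seqlen, head_dim, *, fuse_qkv=True,
                       decode=False, transpose_batch_sequence=False,
                       kernel_available=True, enable_fused_attn=True):
    """Sketch of the use_fused_attn condition in MultiHeadAttention above."""
    return (not decode and not transpose_batch_sequence and fuse_qkv
            and dtype in (jnp.bfloat16, jnp.float16)
            and q_seqlen % 64 == 0 and kv_seqlen % 64 == 0
            and head_dim in (64, 128)
            and kernel_available and enable_fused_attn)

assert can_use_fused_attn(jnp.float16, 2048, 2048, 128)     # beyond the old 512 cap
assert not can_use_fused_attn(jnp.float16, 2048, 2048, 80)  # unsupported head_dim
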
......@@ -12,8 +12,8 @@ import transformer_engine_jax
from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
from .cpp_extensions import cross_fused_attn_max_512_fwd, cross_fused_attn_max_512_bwd
from .cpp_extensions import self_fused_attn_max_512_fwd, self_fused_attn_max_512_bwd
from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd
from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd
from .sharding import get_fused_attn_sharding_meta
from .sharding import ShardingType
from .sharding import xmap_runner
......@@ -57,10 +57,10 @@ def self_fused_attn(qkv: jnp.ndarray,
Self fused attention wrapper
"""
assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
"Fused_attn_max_512 does not support row-split tensor parallelism currently."
"self_fused_attn does not support row-split tensor parallelism currently."
if sharding_type is ShardingType.SINGLE:
output = _self_fused_attn_max_512(qkv,
output = _self_fused_attn(qkv,
bias,
mask,
seed,
......@@ -87,14 +87,14 @@ def self_fused_attn(qkv: jnp.ndarray,
jnp.reshape(x, new_shape) if x is not None else None
for x, new_shape in zip(inputs, sharding_meta.input_shapes))
partial_self_fused_attn_max_512 = partial(_self_fused_attn_max_512,
partial_self_fused_attn = partial(_self_fused_attn,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
output_ = xmap_runner(partial_self_fused_attn_max_512, sharding_meta.in_axes,
output_ = xmap_runner(partial_self_fused_attn, sharding_meta.in_axes,
sharding_meta.out_axes[0], sharding_meta.axis_resources, inputs_)
output = jnp.reshape(output_, sharding_meta.output_shapes[0])
......@@ -103,11 +103,10 @@ def self_fused_attn(qkv: jnp.ndarray,
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _self_fused_attn_max_512(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
seed: jnp.ndarray, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
output, _ = _self_fused_attn_max_512_fwd(qkv,
def _self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
scaling_factor: float, dropout_probability: float, is_training: bool):
output, _ = _self_fused_attn_fwd(qkv,
bias,
mask,
seed,
......@@ -119,14 +118,14 @@ def _self_fused_attn_max_512(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndar
return output
def _self_fused_attn_max_512_fwd(qkv, bias, mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
def _self_fused_attn_fwd(qkv, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
cu_seqlen = jnp.cumsum(seqlen)
cu_seqlen = jnp.hstack((0, cu_seqlen))
output, softmax_aux = self_fused_attn_max_512_fwd(qkv,
output, softmax_aux, rng_state = self_fused_attn_fwd(qkv,
bias,
cu_seqlen,
seed,
......@@ -135,17 +134,19 @@ def _self_fused_attn_max_512_fwd(qkv, bias, mask, seed, attn_bias_type, attn_mas
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output, (qkv, softmax_aux, cu_seqlen)
return output, (qkv, softmax_aux, rng_state, output, cu_seqlen)
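The forward wrapper above derives per-batch sequence lengths from the padding mask (entries equal to 0 mark attendable positions) and turns them into the cumulative-sequence-length array the cuDNN kernels consume; it now also saves rng_state and the forward output as extra residuals for the backward pass. A small sketch of the cu_seqlen derivation, using a made-up mask:

import jax.numpy as jnp

# Hypothetical mask: batch=2, heads=1, q_len=kv_len=4; 0 = attendable, 1 = padded.
mask = jnp.zeros((2, 1, 4, 4), dtype=jnp.uint8)
mask = mask.at[1, :, 3:, :].set(1).at[1, :, :, 3:].set(1)   # pad last token of sample 1

seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)  # [4, 3]
cu_seqlen = jnp.hstack((0, jnp.cumsum(seqlen)))                          # [0, 4, 7]
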
def _self_fused_attn_max_512_bwd(attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training, ctx, grad):
qkv, softmax_aux, cu_seqlen = ctx
def _self_fused_attn_bwd(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, grad):
qkv, softmax_aux, rng_state, output, cu_seqlen = ctx
doutput = grad
grad_qkv, grad_bias = self_fused_attn_max_512_bwd(qkv,
grad_qkv, grad_bias = self_fused_attn_bwd(qkv,
softmax_aux,
rng_state,
output,
doutput,
cu_seqlen,
attn_bias_type=attn_bias_type.value,
......@@ -154,10 +155,13 @@ def _self_fused_attn_max_512_bwd(attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
grad_bias = None
return grad_qkv, grad_bias, None, None
_self_fused_attn_max_512.defvjp(_self_fused_attn_max_512_fwd, _self_fused_attn_max_512_bwd)
_self_fused_attn.defvjp(_self_fused_attn_fwd, _self_fused_attn_bwd)
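As above, _self_fused_attn remains a jax.custom_vjp function: the forward returns the residuals (qkv, softmax_aux, rng_state, output, cu_seqlen), and the backward produces gradients only for the differentiable arguments, returning None for mask and seed. A toy custom_vjp with the same wiring (the math is a placeholder, not the attention kernel):

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(2,))
def toy_attn(qkv, bias, scaling_factor):
    out, _ = _toy_attn_fwd(qkv, bias, scaling_factor)
    return out

def _toy_attn_fwd(qkv, bias, scaling_factor):
    out = qkv * scaling_factor + bias
    return out, (qkv, bias)              # residuals saved for the backward pass

def _toy_attn_bwd(scaling_factor, ctx, grad):
    qkv, bias = ctx                      # unpack residuals, mirroring the diff
    return grad * scaling_factor, grad   # one cotangent per differentiable arg

toy_attn.defvjp(_toy_attn_fwd, _toy_attn_bwd)

dx = jax.grad(lambda x: toy_attn(x, jnp.zeros(4), 0.5).sum())(jnp.ones(4))
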
def cross_fused_attn(q: jnp.ndarray,
......@@ -174,10 +178,10 @@ def cross_fused_attn(q: jnp.ndarray,
Cross multi-head attention wrapper
"""
assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
"Fused_attn_max_512 does not support row-split tensor parallelism currently."
"cross_fused_attn does not support row-split tensor parallelism currently."
if sharding_type is ShardingType.SINGLE:
output = _cross_fused_attn_max_512(q,
output = _cross_fused_attn(q,
kv,
mask,
seed,
......@@ -203,14 +207,14 @@ def cross_fused_attn(q: jnp.ndarray,
jnp.reshape(x, new_shape) if x is not None else None
for x, new_shape in zip(inputs, sharding_meta.input_shapes))
partial_cross_fused_attn_max_512 = partial(_cross_fused_attn_max_512,
partial_cross_fused_attn = partial(_cross_fused_attn,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
output_ = xmap_runner(partial_cross_fused_attn_max_512, sharding_meta.in_axes,
output_ = xmap_runner(partial_cross_fused_attn, sharding_meta.in_axes,
sharding_meta.out_axes[0], sharding_meta.axis_resources, inputs_)
output = jnp.reshape(output_, sharding_meta.output_shapes[0])
......@@ -219,11 +223,11 @@ def cross_fused_attn(q: jnp.ndarray,
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _cross_fused_attn_max_512(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
scaling_factor: float, dropout_probability: float, is_training: bool):
output, _ = _cross_fused_attn_max_512_fwd(q,
output, _ = _cross_fused_attn_fwd(q,
kv,
mask,
seed,
......@@ -235,7 +239,7 @@ def _cross_fused_attn_max_512(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray
return output
def _cross_fused_attn_max_512_fwd(q, kv, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
def _cross_fused_attn_fwd(q, kv, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
q_seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
......@@ -246,7 +250,7 @@ def _cross_fused_attn_max_512_fwd(q, kv, mask, seed, attn_bias_type, attn_mask_t
kv_cu_seqlen = jnp.cumsum(kv_seqlen)
kv_cu_seqlen = jnp.hstack((0, kv_cu_seqlen))
output, softmax_aux = cross_fused_attn_max_512_fwd(q,
output, softmax_aux = cross_fused_attn_fwd(q,
kv,
q_cu_seqlen,
kv_cu_seqlen,
......@@ -259,13 +263,13 @@ def _cross_fused_attn_max_512_fwd(q, kv, mask, seed, attn_bias_type, attn_mask_t
return output, (softmax_aux, q, kv, q_cu_seqlen, kv_cu_seqlen)
def _cross_fused_attn_max_512_bwd(attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training, ctx, grad):
def _cross_fused_attn_bwd(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, grad):
softmax_aux, q, kv, q_cu_seqlen, kv_cu_seqlen = ctx
doutput = grad
grad_q, grad_kv = cross_fused_attn_max_512_bwd(q,
grad_q, grad_kv = cross_fused_attn_bwd(q,
kv,
softmax_aux,
doutput,
......@@ -280,4 +284,4 @@ def _cross_fused_attn_max_512_bwd(attn_bias_type, attn_mask_type, scaling_factor
return grad_q, grad_kv, None, None
_cross_fused_attn_max_512.defvjp(_cross_fused_attn_max_512_fwd, _cross_fused_attn_max_512_bwd)
_cross_fused_attn.defvjp(_cross_fused_attn_fwd, _cross_fused_attn_bwd)