Unverified Commit 66ff2e36 authored by zlsh80826's avatar zlsh80826 Committed by GitHub
Browse files

[JAX] flash attention integration (#345)



* Fix flash attention dropout probability with inference
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Add output as the fused attention ctx tensor
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Add rng_state as the fused attention ctx tensors
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Add flash attention supported lengths to the fused attention
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Refactor attention primitive to reuse abstract shaped array
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Detect backend type to allocate appropriate ctx size
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Skip dropout correctness instead of return success
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Use cudaMemsetAsync and enhance the error handling
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Add flash attention kernel elts_per_thread update
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Remove redundant max 512 suffix
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Keep only DType and remove NVTEDType from python
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Fix a float32_attention_logits bugs
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Re-calculate workspace size for self attention
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Enhance bias/dbias shape guard
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Enhance the seed/rng_state checker
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Use jax.core.ShapedArray as jax.abstract_arrays is deprecated
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Enhance the unittest docs
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

---------
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>
parent 403ade2f
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import pytest import pytest
import jax.numpy as jnp import jax.numpy as jnp
from jax.abstract_arrays import ShapedArray from jax.core import ShapedArray
from transformer_engine_jax import DType from transformer_engine_jax import DType
from transformer_engine.jax.cpp_extensions import te_dtype_to_jax_dtype from transformer_engine.jax.cpp_extensions import te_dtype_to_jax_dtype
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
"""Tests for fused attention"""
from typing import Optional import os
import math from enum import Enum
from math import sqrt
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
...@@ -14,8 +16,6 @@ from flax.linen import combine_masks ...@@ -14,8 +16,6 @@ from flax.linen import combine_masks
from flax.linen import dot_product_attention from flax.linen import dot_product_attention
from flax.linen import make_attention_mask from flax.linen import make_attention_mask
from flax.linen import make_causal_mask from flax.linen import make_causal_mask
from jax import lax
from jax import nn as jax_nn
from jax import value_and_grad, jit from jax import value_and_grad, jit
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType
...@@ -25,19 +25,45 @@ from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn ...@@ -25,19 +25,45 @@ from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn
# Type annotations # Type annotations
Array = jnp.ndarray Array = jnp.ndarray
SELF_CASES = [(32, 512, 16, 64), (32, 128, 16, 64)]
class Backend(Enum):
"""
Fused attn backend.
Unit tests only, transformer will auto dispatch to the best backend
"""
Max512 = "0"
Arbitrary = "1"
@pytest.fixture(name="backend", params=[Backend.Max512, Backend.Arbitrary])
def fixture_backend(request):
"""
Fixture of setting up/tearing down backend
"""
backend = request.param
os.environ["NVTE_FUSED_ATTN_BACKEND"] = backend.value
yield backend
os.environ["NVTE_FUSED_ATTN_BACKEND"] = ""
SELF_CASES = [(32, 512, 16, 64), (32, 128, 16, 64), (4, 2048, 12, 64)]
CROSS_CASES = [(32, 128, 512, 16, 64)] CROSS_CASES = [(32, 128, 512, 16, 64)]
DTYPES = [jnp.bfloat16, jnp.float16] DTYPES = [jnp.bfloat16, jnp.float16]
PAD_RATIO = [0.3]
def make_decoder_mask(tokens: Array) -> Array: def make_decoder_mask(tokens: Array) -> Array:
"""
Create padded causal mask
"""
causal_mask = make_causal_mask(tokens) causal_mask = make_causal_mask(tokens)
padding_mask = make_attention_mask(tokens > 0, tokens > 0) padding_mask = make_attention_mask(tokens > 0, tokens > 0)
return combine_masks(causal_mask, padding_mask) return combine_masks(causal_mask, padding_mask)
def jax_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs): def jax_self_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
"""
Self attention with JAX native implementation
"""
attn_mask_type = kwargs['attn_mask_type'] attn_mask_type = kwargs['attn_mask_type']
if attn_mask_type == AttnMaskType.CAUSAL_MASK: if attn_mask_type == AttnMaskType.CAUSAL_MASK:
mask = make_decoder_mask(q_token) mask = make_decoder_mask(q_token)
...@@ -61,7 +87,10 @@ def jax_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs): ...@@ -61,7 +87,10 @@ def jax_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
return output return output
def jax_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs): def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
"""
Cross attention with JAX native implementation
"""
assert q.dtype == kv.dtype assert q.dtype == kv.dtype
attn_mask_type = kwargs['attn_mask_type'] attn_mask_type = kwargs['attn_mask_type']
...@@ -87,6 +116,9 @@ def jax_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs): ...@@ -87,6 +116,9 @@ def jax_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
def customcall_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs): def customcall_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
"""
Self fused attention
"""
if kwargs['attn_mask_type'] == AttnMaskType.CAUSAL_MASK: if kwargs['attn_mask_type'] == AttnMaskType.CAUSAL_MASK:
mask = make_decoder_mask(q_token) mask = make_decoder_mask(q_token)
else: else:
...@@ -99,6 +131,9 @@ def customcall_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwar ...@@ -99,6 +131,9 @@ def customcall_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwar
def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs): def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
"""
Cross fused attention
"""
assert q.dtype == kv.dtype assert q.dtype == kv.dtype
if kwargs['attn_mask_type'] == AttnMaskType.CAUSAL_MASK: if kwargs['attn_mask_type'] == AttnMaskType.CAUSAL_MASK:
...@@ -113,10 +148,33 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs) ...@@ -113,10 +148,33 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs)
@pytest.mark.skipif(not is_fused_attn_kernel_available(), @pytest.mark.skipif(not is_fused_attn_kernel_available(),
reason="Fused attention kernel is not supported.") reason="Fused attention kernel is not supported.")
class TestSelfFusedAttnMax512(): @pytest.mark.parametrize('b, s, h, d', SELF_CASES)
@pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS])
def set_input(self, b, s, h, d, *, attn_bias_type, attn_mask_type, dropout_probability, dtype, @pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
is_training, pad_ratio): @pytest.mark.parametrize('dropout_probability', [0., 0.1])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True, False])
@pytest.mark.parametrize('pad_ratio', [0, 0.3])
class TestSelfFusedAttn():
"""Tests for transformer_engine.jax.fused_attn.self_fused_attn"""
@staticmethod
def _check_inputs(s, *, attn_bias_type, attn_mask_type, backend, pad_ratio):
# Arbitrary seqlen backend has a limited spec for now
# No bias, only causal mask, and no variable seqlen
if (s > 512 or backend == Backend.Arbitrary) and (attn_bias_type != AttnBiasType.NO_BIAS or
attn_mask_type != AttnMaskType.CAUSAL_MASK
or pad_ratio != 0):
pytest.skip("Unsupported inputs combination.")
def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend,
dropout_probability, dtype, is_training, pad_ratio):
"""Setup the test inputs"""
self.__class__._check_inputs(s,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
backend=backend,
pad_ratio=pad_ratio)
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2) subkeys = jax.random.split(key, 2)
...@@ -137,82 +195,29 @@ class TestSelfFusedAttnMax512(): ...@@ -137,82 +195,29 @@ class TestSelfFusedAttnMax512():
axis=-1) axis=-1)
self.kv_token = self.q_token self.kv_token = self.q_token
self.scaling_factor = 1. / math.sqrt(d) self.scaling_factor = 1. / sqrt(d)
self.dropout_probability = dropout_probability self.dropout_probability = dropout_probability
self.dropout_rng = jax.random.PRNGKey(0) if self.dropout_probability > 0 else None self.dropout_rng = jax.random.PRNGKey(0) if self.dropout_probability > 0 else None
self.attn_bias_type = attn_bias_type self.attn_bias_type = attn_bias_type
self.attn_mask_type = attn_mask_type
self.is_training = is_training self.is_training = is_training
@pytest.mark.parametrize('b, s, h, d', SELF_CASES) def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend, dropout_probability,
@pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS]) dtype, is_training, pad_ratio):
@pytest.mark.parametrize('attn_mask_type', """
[AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]) Test forward without using JIT
@pytest.mark.parametrize('dropout_probability', [0., 0.1]) """
@pytest.mark.parametrize('dtype', DTYPES) self._set_inputs(b,
@pytest.mark.parametrize('is_training', [True, False]) s,
@pytest.mark.parametrize('pad_ratio', PAD_RATIO) h,
def test_sanity(self, b, s, h, d, attn_bias_type, attn_mask_type, dropout_probability, dtype, d,
is_training, pad_ratio): attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
def grad_func(func, *args, **kwargs): backend=backend,
# Keep only valid result for the gradient dropout_probability=dropout_probability,
# fused_attn_max_512 output has shape (b, s, h, d) dtype=dtype,
valid_ret, _ = jnp.split(func(*args, **kwargs), (self.valid_len,), axis=1) is_training=is_training,
return jnp.mean(valid_ret, dtype=jnp.float32).astype(dtype) pad_ratio=pad_ratio)
self.set_input(b,
s,
h,
d,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
dropout_probability=dropout_probability,
dtype=dtype,
is_training=is_training,
pad_ratio=pad_ratio)
kwargs = {
'attn_bias_type': self.attn_bias_type,
'attn_mask_type': attn_mask_type,
'scaling_factor': self.scaling_factor,
'dropout_probability': self.dropout_probability,
'is_training': self.is_training
}
jitted_primitive = jit(
value_and_grad(
lambda qkv, bias, q_token, kv_token, dropout_rng: grad_func(
customcall_self_fused_attn, qkv, bias, q_token, kv_token, dropout_rng, **kwargs
), (0, 1)))
primitive_out, (primitive_dqkv,
primitive_dbias) = jitted_primitive(self.qkv, self.bias, self.q_token,
self.kv_token, self.dropout_rng)
@pytest.mark.parametrize('b, s, h, d', SELF_CASES)
@pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS])
@pytest.mark.parametrize('attn_mask_type',
[AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
@pytest.mark.parametrize('dropout_probability', [0., 0.1])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True, False])
@pytest.mark.parametrize('pad_ratio', PAD_RATIO)
def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, dropout_probability, dtype,
is_training, pad_ratio):
# dropout can't get the bitmatch result
if is_training and dropout_probability > 0.:
return
self.set_input(b,
s,
h,
d,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
dropout_probability=dropout_probability,
dtype=dtype,
is_training=is_training,
pad_ratio=pad_ratio)
primitive_out = customcall_self_fused_attn(self.qkv, primitive_out = customcall_self_fused_attn(self.qkv,
self.bias, self.bias,
...@@ -225,19 +230,23 @@ class TestSelfFusedAttnMax512(): ...@@ -225,19 +230,23 @@ class TestSelfFusedAttnMax512():
dropout_probability=self.dropout_probability, dropout_probability=self.dropout_probability,
is_training=self.is_training) is_training=self.is_training)
reference_out = jax_self_fused_attn(self.qkv, reference_out = jax_self_attn(self.qkv,
self.bias, self.bias,
self.q_token, self.q_token,
self.kv_token, self.kv_token,
self.dropout_rng, self.dropout_rng,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
scaling_factor=self.scaling_factor, scaling_factor=self.scaling_factor,
dropout_probability=self.dropout_probability, dropout_probability=self.dropout_probability,
is_training=self.is_training) is_training=self.is_training)
ref_valid, _ = jnp.split(reference_out, (self.valid_len,), axis=1) ref_valid, _ = jnp.split(reference_out, (self.valid_len,), axis=1)
pri_valid, pri_invalid = jnp.split(primitive_out, (self.valid_len,), axis=1) pri_valid, pri_invalid = jnp.split(primitive_out, (self.valid_len,), axis=1)
# Dropout can't get the bitmatch result, skip the elementwise comparison
if is_training and dropout_probability > 0.:
return
np.testing.assert_allclose(jnp.asarray(pri_valid, np.float32), np.testing.assert_allclose(jnp.asarray(pri_valid, np.float32),
jnp.asarray(ref_valid, np.float32), jnp.asarray(ref_valid, np.float32),
rtol=1e-4, rtol=1e-4,
...@@ -246,38 +255,36 @@ class TestSelfFusedAttnMax512(): ...@@ -246,38 +255,36 @@ class TestSelfFusedAttnMax512():
np.testing.assert_allclose(jnp.asarray(pri_invalid, jnp.float32), np.testing.assert_allclose(jnp.asarray(pri_invalid, jnp.float32),
jnp.zeros_like(pri_invalid, jnp.float32)) jnp.zeros_like(pri_invalid, jnp.float32))
@pytest.mark.parametrize('b, s, h, d', SELF_CASES) def test_forward_backward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend,
@pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS]) dropout_probability, dtype, is_training, pad_ratio):
@pytest.mark.parametrize('attn_mask_type', """
[AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]) Test forward, backward, and autodiff by jax.value_and_grad
@pytest.mark.parametrize('dropout_probability', [0.]) # dropout can't get the bitmatch result """
@pytest.mark.parametrize('dtype', DTYPES) if not is_training:
@pytest.mark.parametrize('is_training', [True]) # backward is only used when is_training pytest.skip(f"Backward doesn't support {is_training=}")
@pytest.mark.parametrize('pad_ratio', PAD_RATIO)
def test_forward_backward(self, b, s, h, d, attn_bias_type, attn_mask_type, dropout_probability, self._set_inputs(b,
dtype, is_training, pad_ratio): s,
self.set_input(b, h,
s, d,
h, attn_bias_type=attn_bias_type,
d, attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type, backend=backend,
attn_mask_type=attn_mask_type, dropout_probability=dropout_probability,
dropout_probability=dropout_probability, dtype=dtype,
dtype=dtype, is_training=is_training,
is_training=is_training, pad_ratio=pad_ratio)
pad_ratio=pad_ratio)
def grad_func(fused_attn_func, *args, **kwargs):
def grad_func(fused_attn_max_512_func, *args, **kwargs):
# Gradient is small, use a gradient multiplier to amplify the graident # Gradient is small, use a gradient multiplier to amplify the graident
gradient_multiplier = 1000 if dtype == jnp.bfloat16 else 10000 gradient_multiplier = 1000 if dtype == jnp.bfloat16 else 10000
if attn_mask_type == AttnMaskType.CAUSAL_MASK: if attn_mask_type == AttnMaskType.CAUSAL_MASK:
gradient_multiplier = gradient_multiplier / 10 gradient_multiplier = gradient_multiplier / 10
# Keep only valid result for the gradient # Keep only valid result for the gradient
# fused_attn_max_512 output has shape (b, s, h, d) # fused_attn output has shape (b, s, h, d)
valid_fused_attn_max_512_ret, _ = jnp.split(fused_attn_max_512_func(*args, **kwargs), valid_fused_attn_ret, _ = jnp.split(fused_attn_func(*args, **kwargs), (self.valid_len,),
(self.valid_len,), axis=1)
axis=1) return (jnp.mean(valid_fused_attn_ret, dtype=jnp.float32) *
return (jnp.mean(valid_fused_attn_max_512_ret, dtype=jnp.float32) *
gradient_multiplier).astype(dtype) gradient_multiplier).astype(dtype)
kwargs = { kwargs = {
...@@ -298,8 +305,7 @@ class TestSelfFusedAttnMax512(): ...@@ -298,8 +305,7 @@ class TestSelfFusedAttnMax512():
jitted_reference = jit( jitted_reference = jit(
value_and_grad( value_and_grad(
lambda qkv, bias, q_token, kv_token, dropout_rng: grad_func( lambda qkv, bias, q_token, kv_token, dropout_rng: grad_func(
jax_self_fused_attn, qkv, bias, q_token, kv_token, dropout_rng, **kwargs), jax_self_attn, qkv, bias, q_token, kv_token, dropout_rng, **kwargs), (0, 1)))
(0, 1)))
primitive_out, (primitive_dqkv, primitive_out, (primitive_dqkv,
primitive_dbias) = jitted_primitive(self.qkv, self.bias, self.q_token, primitive_dbias) = jitted_primitive(self.qkv, self.bias, self.q_token,
...@@ -309,6 +315,10 @@ class TestSelfFusedAttnMax512(): ...@@ -309,6 +315,10 @@ class TestSelfFusedAttnMax512():
reference_dbias) = jitted_reference(self.qkv, self.bias, self.q_token, reference_dbias) = jitted_reference(self.qkv, self.bias, self.q_token,
self.kv_token, self.dropout_rng) self.kv_token, self.dropout_rng)
# Dropout can't get the bitmatch result, skip the elementwise comparison
if dropout_probability > 0.:
return
np.testing.assert_allclose(jnp.asarray(primitive_out, np.float32), np.testing.assert_allclose(jnp.asarray(primitive_out, np.float32),
jnp.asarray(reference_out, np.float32), jnp.asarray(reference_out, np.float32),
rtol=1e-4, rtol=1e-4,
...@@ -319,23 +329,14 @@ class TestSelfFusedAttnMax512(): ...@@ -319,23 +329,14 @@ class TestSelfFusedAttnMax512():
valid_reference_dqkv, invalid_reference_dqkv = jnp.split(reference_dqkv, (self.valid_len,), valid_reference_dqkv, invalid_reference_dqkv = jnp.split(reference_dqkv, (self.valid_len,),
axis=1) axis=1)
# dQ valid_primitive_dq, valid_primitive_dk, valid_primitive_dv = jnp.split(
np.testing.assert_allclose(jnp.asarray(valid_primitive_dqkv[:, :, 0], np.float32), valid_primitive_dqkv.astype(jnp.float32), 3, axis=2)
jnp.asarray(valid_reference_dqkv[:, :, 0], np.float32), valid_reference_dq, valid_reference_dk, valid_reference_dv = jnp.split(
rtol=1e-4, valid_reference_dqkv.astype(jnp.float32), 3, axis=2)
atol=1e-5)
# dK
np.testing.assert_allclose(jnp.asarray(valid_primitive_dqkv[:, :, 1], np.float32),
jnp.asarray(valid_reference_dqkv[:, :, 1], np.float32),
rtol=1e-4,
atol=1e-5)
# dV np.testing.assert_allclose(valid_primitive_dq, valid_reference_dq, rtol=1e-4, atol=1e-5)
np.testing.assert_allclose(jnp.asarray(valid_primitive_dqkv[:, :, 2], np.float32), np.testing.assert_allclose(valid_primitive_dk, valid_reference_dk, rtol=1e-4, atol=1e-5)
jnp.asarray(valid_reference_dqkv[:, :, 2], np.float32), np.testing.assert_allclose(valid_primitive_dv, valid_reference_dv, rtol=1e-4, atol=1e-5)
rtol=1e-4,
atol=1e-5)
assert jnp.allclose(invalid_primitive_dqkv, invalid_reference_dqkv) assert jnp.allclose(invalid_primitive_dqkv, invalid_reference_dqkv)
...@@ -362,10 +363,17 @@ class TestSelfFusedAttnMax512(): ...@@ -362,10 +363,17 @@ class TestSelfFusedAttnMax512():
@pytest.mark.skipif(not is_fused_attn_kernel_available(), @pytest.mark.skipif(not is_fused_attn_kernel_available(),
reason="Fused attention kernel is not supported.") reason="Fused attention kernel is not supported.")
class TestCrossFusedAttnMax512(): @pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_CASES)
@pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK])
def set_input(self, b, s_q, s_kv, h, d, *, attn_mask_type, dropout_probability, dtype, @pytest.mark.parametrize('dropout_probability', [0., 0.1])
is_training, pad_ratio): @pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True, False])
@pytest.mark.parametrize('pad_ratio', [0.3])
class TestCrossFusedAttn():
"""Tests for transformer_engine.jax.fused_attn.cross_fused_attn"""
def _set_inputs(self, b, s_q, s_kv, h, d, *, attn_mask_type, dropout_probability, dtype,
is_training, pad_ratio):
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2) subkeys = jax.random.split(key, 2)
...@@ -385,34 +393,28 @@ class TestCrossFusedAttnMax512(): ...@@ -385,34 +393,28 @@ class TestCrossFusedAttnMax512():
self.kv_token = jnp.concatenate((jnp.ones((b, self.kv_valid_len)), jnp.zeros( self.kv_token = jnp.concatenate((jnp.ones((b, self.kv_valid_len)), jnp.zeros(
(b, kv_pad_len))), (b, kv_pad_len))),
axis=-1) axis=-1)
self.scaling_factor = 1. / math.sqrt(d) self.scaling_factor = 1. / sqrt(d)
self.dropout_probability = dropout_probability self.dropout_probability = dropout_probability
self.dropout_rng = jax.random.PRNGKey(0) if self.dropout_probability > 0 else None self.dropout_rng = jax.random.PRNGKey(0) if self.dropout_probability > 0 else None
self.attn_bias_type = AttnBiasType.NO_BIAS self.attn_bias_type = AttnBiasType.NO_BIAS
self.attn_mask_type = attn_mask_type
self.is_training = is_training self.is_training = is_training
@pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_CASES)
@pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK])
@pytest.mark.parametrize('dropout_probability', [0., 0.1])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True, False])
@pytest.mark.parametrize('pad_ratio', PAD_RATIO)
def test_forward(self, b, s_q, s_kv, h, d, attn_mask_type, dropout_probability, dtype, def test_forward(self, b, s_q, s_kv, h, d, attn_mask_type, dropout_probability, dtype,
is_training, pad_ratio): is_training, pad_ratio):
# dropout can't get the bitmatch result """
if is_training and dropout_probability > 0.: Test forward without using JIT
return """
self._set_inputs(b,
self.set_input(b, s_q,
s_q, s_kv,
s_kv, h,
h, d,
d, attn_mask_type=attn_mask_type,
attn_mask_type=attn_mask_type, dropout_probability=dropout_probability,
dropout_probability=dropout_probability, dtype=dtype,
dtype=dtype, is_training=is_training,
is_training=is_training, pad_ratio=pad_ratio)
pad_ratio=pad_ratio)
primitive_out = customcall_cross_fused_attn(self.q, primitive_out = customcall_cross_fused_attn(self.q,
self.kv, self.kv,
...@@ -425,15 +427,19 @@ class TestCrossFusedAttnMax512(): ...@@ -425,15 +427,19 @@ class TestCrossFusedAttnMax512():
dropout_probability=self.dropout_probability, dropout_probability=self.dropout_probability,
is_training=self.is_training) is_training=self.is_training)
reference_out = jax_cross_fused_attn(self.q, reference_out = jax_cross_attn(self.q,
self.kv, self.kv,
self.q_token, self.q_token,
self.kv_token, self.kv_token,
self.dropout_rng, self.dropout_rng,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
scaling_factor=self.scaling_factor, scaling_factor=self.scaling_factor,
dropout_probability=self.dropout_probability, dropout_probability=self.dropout_probability,
is_training=self.is_training) is_training=self.is_training)
# Dropout can't get the bitmatch result, skip the elementwise comparison
if is_training and dropout_probability > 0.:
return
ref_valid, _ = jnp.split(reference_out, (self.q_valid_len,), axis=1) ref_valid, _ = jnp.split(reference_out, (self.q_valid_len,), axis=1)
pri_valid, pri_invalid = jnp.split(primitive_out, (self.q_valid_len,), axis=1) pri_valid, pri_invalid = jnp.split(primitive_out, (self.q_valid_len,), axis=1)
...@@ -446,36 +452,36 @@ class TestCrossFusedAttnMax512(): ...@@ -446,36 +452,36 @@ class TestCrossFusedAttnMax512():
np.testing.assert_allclose(jnp.asarray(pri_invalid, jnp.float32), np.testing.assert_allclose(jnp.asarray(pri_invalid, jnp.float32),
jnp.zeros_like(pri_invalid, jnp.float32)) jnp.zeros_like(pri_invalid, jnp.float32))
@pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_CASES)
@pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK])
@pytest.mark.parametrize('dropout_probability', [0.]) # dropout can't get the bitmatch result
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True]) # backward is only used when is_training
@pytest.mark.parametrize('pad_ratio', PAD_RATIO)
def test_forward_backward(self, b, s_q, s_kv, h, d, attn_mask_type, dropout_probability, dtype, def test_forward_backward(self, b, s_q, s_kv, h, d, attn_mask_type, dropout_probability, dtype,
is_training, pad_ratio): is_training, pad_ratio):
self.set_input(b, """
s_q, Test forward, backward, and autodiff by jax.value_and_grad
s_kv, """
h, if not is_training:
d, pytest.skip(f"Backward doesn't support {is_training=}")
attn_mask_type=attn_mask_type,
dropout_probability=dropout_probability, self._set_inputs(b,
dtype=dtype, s_q,
is_training=is_training, s_kv,
pad_ratio=pad_ratio) h,
d,
def grad_func(fused_attn_max_512_func, *args, **kwargs): attn_mask_type=attn_mask_type,
dropout_probability=dropout_probability,
dtype=dtype,
is_training=is_training,
pad_ratio=pad_ratio)
def grad_func(fused_attn_func, *args, **kwargs):
# Gradient is small, use a gradient multiplier to amplify the graident # Gradient is small, use a gradient multiplier to amplify the graident
gradient_multiplier = 10000 gradient_multiplier = 10000
if attn_mask_type == AttnMaskType.CAUSAL_MASK: if attn_mask_type == AttnMaskType.CAUSAL_MASK:
gradient_multiplier = gradient_multiplier / 10 gradient_multiplier = gradient_multiplier / 10
# Keep only valid result for the gradient # Keep only valid result for the gradient
# fused_attn_max_512 output has shape (b, s_q, h, d) # fused_attn output has shape (b, s_q, h, d)
valid_fused_attn_max_512_ret, _ = jnp.split(fused_attn_max_512_func(*args, **kwargs), valid_fused_attn_ret, _ = jnp.split(fused_attn_func(*args, **kwargs),
(self.q_valid_len,), (self.q_valid_len,),
axis=1) axis=1)
return (jnp.mean(valid_fused_attn_max_512_ret, dtype=jnp.float32) * return (jnp.mean(valid_fused_attn_ret, dtype=jnp.float32) *
gradient_multiplier).astype(dtype) gradient_multiplier).astype(dtype)
kwargs = { kwargs = {
...@@ -496,7 +502,7 @@ class TestCrossFusedAttnMax512(): ...@@ -496,7 +502,7 @@ class TestCrossFusedAttnMax512():
jitted_reference = jit( jitted_reference = jit(
value_and_grad( value_and_grad(
lambda q, kv, q_token, kv_token, dropout_rng: grad_func( lambda q, kv, q_token, kv_token, dropout_rng: grad_func(
jax_cross_fused_attn, q, kv, q_token, kv_token, dropout_rng, **kwargs), (0, 1))) jax_cross_attn, q, kv, q_token, kv_token, dropout_rng, **kwargs), (0, 1)))
primitive_out, (primitive_dq, primitive_out, (primitive_dq,
primitive_dkv) = jitted_primitive(self.q, self.kv, self.q_token, primitive_dkv) = jitted_primitive(self.q, self.kv, self.q_token,
...@@ -506,6 +512,10 @@ class TestCrossFusedAttnMax512(): ...@@ -506,6 +512,10 @@ class TestCrossFusedAttnMax512():
reference_dkv) = jitted_reference(self.q, self.kv, self.q_token, reference_dkv) = jitted_reference(self.q, self.kv, self.q_token,
self.kv_token, self.dropout_rng) self.kv_token, self.dropout_rng)
# Dropout can't get the bitmatch result, skip the elementwise comparison
if dropout_probability > 0.:
return
np.testing.assert_allclose(jnp.asarray(primitive_out, np.float32), np.testing.assert_allclose(jnp.asarray(primitive_out, np.float32),
jnp.asarray(reference_out, np.float32), jnp.asarray(reference_out, np.float32),
rtol=1e-4, rtol=1e-4,
......
...@@ -547,7 +547,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -547,7 +547,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
if (!is_training) { if (!is_training) {
dropout_probability == 0.0f; dropout_probability = 0.0f;
} }
FADescriptor descriptor{b, h, FADescriptor descriptor{b, h,
...@@ -1144,7 +1144,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -1144,7 +1144,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
void *devPtrSoftmaxSum = static_cast<int8_t *>(workspace) + plan_workspace_size; void *devPtrSoftmaxSum = static_cast<int8_t *>(workspace) + plan_workspace_size;
void *devPtrdQAccumulator = static_cast<int8_t *>(devPtrSoftmaxSum) void *devPtrdQAccumulator = static_cast<int8_t *>(devPtrSoftmaxSum)
+ softmaxSum_workspace_size; + softmaxSum_workspace_size;
NVTE_CHECK_CUDA(cudaMemset(devPtrdQAccumulator, 0, dqAccum_workspace_size)); NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdQAccumulator, 0, dqAccum_workspace_size, stream));
std::set<std::pair<uint64_t, void *>> data_ptrs; std::set<std::pair<uint64_t, void *>> data_ptrs;
// add all the data pointers to be used in the variant pack // add all the data pointers to be used in the variant pack
...@@ -1224,6 +1224,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( ...@@ -1224,6 +1224,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
devPtrS = output_S->data.dptr; devPtrS = output_S->data.dptr;
Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]); Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr; output_rng_state->data.dptr = rng_state->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
} }
void* devPtrDropoutSeed = rng_state->data.dptr; void* devPtrDropoutSeed = rng_state->data.dptr;
...@@ -1250,6 +1252,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( ...@@ -1250,6 +1252,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
workspace->data.shape = {1}; workspace->data.shape = {1};
workspace->data.dtype = DType::kByte; workspace->data.dtype = DType::kByte;
return; return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
} }
} }
...@@ -1312,6 +1316,8 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen, ...@@ -1312,6 +1316,8 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen,
workspace->data.shape = {1}; workspace->data.shape = {1};
workspace->data.dtype = DType::kByte; workspace->data.dtype = DType::kByte;
return; return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
} }
} }
} // namespace transformer_engine } // namespace transformer_engine
......
...@@ -1275,6 +1275,8 @@ void fused_attn_max_512_fwd_qkvpacked( ...@@ -1275,6 +1275,8 @@ void fused_attn_max_512_fwd_qkvpacked(
} else if (Aux_CTX_Tensors->size == 1) { } else if (Aux_CTX_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]); Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr; devPtrS = output_S->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
} }
void *devPtrCuSeqlen = cu_seqlens->data.dptr; void *devPtrCuSeqlen = cu_seqlens->data.dptr;
...@@ -1304,6 +1306,8 @@ void fused_attn_max_512_fwd_qkvpacked( ...@@ -1304,6 +1306,8 @@ void fused_attn_max_512_fwd_qkvpacked(
workspace->data.shape = {1}; workspace->data.shape = {1};
workspace->data.dtype = DType::kByte; workspace->data.dtype = DType::kByte;
return; return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
} }
} }
...@@ -1351,6 +1355,8 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k ...@@ -1351,6 +1355,8 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
} else if (Aux_CTX_Tensors->size == 1) { } else if (Aux_CTX_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]); Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr; devPtrS = output_S->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
} }
void *devQCuSeqlen = q_cu_seqlens->data.dptr; void *devQCuSeqlen = q_cu_seqlens->data.dptr;
...@@ -1380,6 +1386,8 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k ...@@ -1380,6 +1386,8 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
workspace->data.shape = {1}; workspace->data.shape = {1};
workspace->data.dtype = DType::kByte; workspace->data.dtype = DType::kByte;
return; return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
} }
} }
...@@ -1440,6 +1448,8 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu ...@@ -1440,6 +1448,8 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
workspace->data.shape = {1}; workspace->data.shape = {1};
workspace->data.dtype = DType::kByte; workspace->data.dtype = DType::kByte;
return; return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
} }
} }
...@@ -1503,6 +1513,8 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k ...@@ -1503,6 +1513,8 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
workspace->data.shape = {1}; workspace->data.shape = {1};
workspace->data.dtype = DType::kByte; workspace->data.dtype = DType::kByte;
return; return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
} }
} }
} // namespace transformer_engine } // namespace transformer_engine
......
...@@ -15,7 +15,7 @@ from jaxlib.hlo_helpers import custom_call ...@@ -15,7 +15,7 @@ from jaxlib.hlo_helpers import custom_call
import jax.numpy as jnp import jax.numpy as jnp
from jax.lib import xla_client from jax.lib import xla_client
from jax import core, dtypes from jax import core, dtypes
from jax.abstract_arrays import ShapedArray from jax.core import ShapedArray
from jax.interpreters import xla, mlir from jax.interpreters import xla, mlir
from jax.interpreters.mlir import ir, dtype_to_ir_type from jax.interpreters.mlir import ir, dtype_to_ir_type
...@@ -23,6 +23,8 @@ import transformer_engine_jax ...@@ -23,6 +23,8 @@ import transformer_engine_jax
from transformer_engine_jax import DType as TEDType from transformer_engine_jax import DType as TEDType
from transformer_engine_jax import NVTE_Bias_Type from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_Fused_Attn_Backend
for _name, _value in transformer_engine_jax.registrations().items(): for _name, _value in transformer_engine_jax.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA") xla_client.register_custom_call_target(_name, _value, platform="CUDA")
...@@ -1981,32 +1983,50 @@ def scaled_upper_triang_masked_softmax_bwd(grad_outputs: jnp.ndarray, softmax_ou ...@@ -1981,32 +1983,50 @@ def scaled_upper_triang_masked_softmax_bwd(grad_outputs: jnp.ndarray, softmax_ou
scale_factor=scale_factor) scale_factor=scale_factor)
def _check_seed(seed, dropout_probability, is_training): @dataclass(frozen=True)
# Jax can't bind None, create a dummy tensor for None class _FusedAttnRNGStateChecker:
if seed is None: """
dropout_enabled = dropout_probability > 0 and is_training Checker for guarding the fused attention rng state.
assert not dropout_enabled, "seed is not allowed to be None when dropout is enabled." The fused attention backend requires a 64 bits seed and a 64 bits offset.
seed = jnp.zeros(2, dtype=jnp.uint32) However, JAX doesn't enable 64 bits by default,
so we have to emulate seed as two 32 bits array.
The offset calculation is maintained in the backend.
"""
rng_state_dtype: jnp.dtype = jnp.uint32
# (seed,) with internal dtype int64
seed_size: int = 2
# (seed, offset) with internal dtype int64
rng_state_size: int = 2 * 2
if seed.dtype != jnp.uint32: def check_seed(self, seed, dropout_probability, is_training):
warnings.warn( """
f"Requested {seed.dtype=} is not available, and will be " Check the seed and convert the data type of seed if possible.
f"casted to dtype uint32. " """
f"Please use threefry/rbg/unsafe_rbg PRNG implementations to remove this warning.") # Jax can't bind None, create a dummy tensor for None
seed = seed.astype(jnp.uint32) if seed is None:
dropout_enabled = dropout_probability > 0 and is_training
assert not dropout_enabled, "seed is not allowed to be None when dropout is enabled."
seed = jnp.zeros(2, dtype=self.rng_state_dtype)
assert seed.dtype == jnp.uint32 if seed.dtype != self.rng_state_dtype:
# Only the first 2 u32 elements are taken warnings.warn(
assert seed.size >= 2 f"Requested {seed.dtype=} is not available, and will be "
f"casted to dtype {self.rng_state_dtype}. "
f"Please use threefry/rbg/unsafe_rbg PRNG implementations to remove this warning.")
seed = seed.astype(self.rng_state_dtype)
return seed assert seed.dtype == self.rng_state_dtype
# Backend takes an int64_t seed, so only the first two u32 elements are taken
assert seed.size >= self.seed_size
return seed
class SelfFusedAttnMax512FwdPrimitive(BasePrimitive):
class SelfFusedAttnFwdPrimitive(BasePrimitive):
""" """
Self Fused Attention Max Seqlen 512 Forward Primitive Self Fused Attention Forward Primitive
""" """
name = "te_self_fused_attn_max_512_forward" name = "te_self_fused_attn_forward"
multiple_results = True multiple_results = True
@staticmethod @staticmethod
...@@ -2023,7 +2043,7 @@ class SelfFusedAttnMax512FwdPrimitive(BasePrimitive): ...@@ -2023,7 +2043,7 @@ class SelfFusedAttnMax512FwdPrimitive(BasePrimitive):
is_training # pylint: disable=unused-argument is_training # pylint: disable=unused-argument
): ):
""" """
Self fused attention max seqlen 512 fwd abstract Self fused attention fwd abstract
""" """
qkv_dtype = dtypes.canonicalize_dtype(qkv.dtype) qkv_dtype = dtypes.canonicalize_dtype(qkv.dtype)
batch, max_seqlen, nqkv, num_head, head_dim = qkv.shape batch, max_seqlen, nqkv, num_head, head_dim = qkv.shape
...@@ -2033,99 +2053,102 @@ class SelfFusedAttnMax512FwdPrimitive(BasePrimitive): ...@@ -2033,99 +2053,102 @@ class SelfFusedAttnMax512FwdPrimitive(BasePrimitive):
output_shape = (batch, max_seqlen, num_head, head_dim) output_shape = (batch, max_seqlen, num_head, head_dim)
output_dtype = qkv_dtype output_dtype = qkv_dtype
softmax_aux_shape = (batch, num_head, max_seqlen, max_seqlen) backend = transformer_engine_jax.get_fused_attn_backend(
softmax_dtype = qkv_dtype jax_dtype_to_te_dtype(qkv_dtype), jax_dtype_to_te_dtype(qkv_dtype),
NVTE_QKV_Layout.NVTE_QKV_INTERLEAVED, attn_bias_type, attn_mask_type,
dropout_probability, max_seqlen, max_seqlen, head_dim)
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_aux_shape = (batch, num_head, max_seqlen, max_seqlen)
softmax_dtype = qkv_dtype
elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
softmax_aux_shape = (batch, num_head, max_seqlen, 1)
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
else:
raise ValueError(f'Not supported {backend=}')
checker = _FusedAttnRNGStateChecker()
seed_dtype = dtypes.canonicalize_dtype(seed.dtype)
assert seed_dtype == checker.rng_state_dtype
rng_state_shape = (checker.rng_state_size,)
rng_state_dtype = seed_dtype
return ( return (
ShapedArray(output_shape, output_dtype, named_shape=qkv.named_shape), # output ShapedArray(output_shape, output_dtype, named_shape=qkv.named_shape), # output
ShapedArray(softmax_aux_shape, softmax_dtype, ShapedArray(softmax_aux_shape, softmax_dtype,
named_shape=qkv.named_shape), # softmax_aux named_shape=qkv.named_shape), # softmax_aux
ShapedArray(rng_state_shape, rng_state_dtype,
named_shape=seed.named_shape), # rng_state
) )
@staticmethod @staticmethod
def lowering(ctx, qkv, bias, cu_seqlen, seed, *, attn_bias_type, attn_mask_type, scaling_factor, def lowering(ctx, qkv, bias, cu_seqlen, seed, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training): dropout_probability, is_training):
""" """
Self fused attention max seqlen 512 fwd lowering rules Self fused attention fwd lowering rules
""" """
qkv_aval, _, _, _ = ctx.avals_in qkv_aval, _, _, _ = ctx.avals_in
ir_qkv_type = ir.RankedTensorType(qkv.type) batch, max_seqlen, _, num_head, head_dim = qkv_aval.shape
ir_qkv_shape = ir_qkv_type.shape
ir_bias_type = ir.RankedTensorType(bias.type)
ir_bias_shape = ir_bias_type.shape
ir_cu_seqlen_type = ir.RankedTensorType(cu_seqlen.type)
ir_cu_seqlen_shape = ir_cu_seqlen_type.shape
ir_seed_type = ir.RankedTensorType(seed.type)
ir_seed_shape = ir_seed_type.shape
batch, max_seqlen, nqkv, num_head, head_dim = ir_qkv_shape
assert nqkv == 3
output_shape = (batch, max_seqlen, num_head, head_dim)
softmax_aux_shape = (batch, num_head, max_seqlen, max_seqlen)
operands = [qkv, bias, cu_seqlen, seed]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [ out_types = [
ir.RankedTensorType.get(output_shape, ir_qkv_type.element_type), ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
ir.RankedTensorType.get(softmax_aux_shape, ir_qkv_type.element_type) for output in ctx.avals_out
] ]
operands = [qkv, bias, cu_seqlen, seed]
operand_shapes = [ir_qkv_shape, ir_bias_shape, ir_cu_seqlen_shape, ir_seed_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_fused_attn_descriptor( opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability, batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability,
attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training) attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
out = custom_caller(SelfFusedAttnMax512FwdPrimitive.name, out = custom_caller(SelfFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
args,
opaque,
has_side_effect=False)
return out return out
_self_fused_attn_max_512_fwd_p = register_primitive(SelfFusedAttnMax512FwdPrimitive) _self_fused_attn_fwd_p = register_primitive(SelfFusedAttnFwdPrimitive)
def self_fused_attn_max_512_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, cu_seqlen: jnp.ndarray, def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, cu_seqlen: jnp.ndarray,
seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type, seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float, attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool): dropout_probability: float, is_training: bool):
""" """
Wrapper for TE self fused attention max seqlen 512 fwd Wrapper for TE self fused attention fwd
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
""" """
seed = _check_seed(seed, dropout_probability, is_training) checker = _FusedAttnRNGStateChecker()
seed = checker.check_seed(seed, dropout_probability, is_training)
if bias is None: if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS assert bias is None
bias = jnp.zeros(0, dtype=qkv.dtype) bias = jnp.zeros(0, dtype=qkv.dtype)
return _self_fused_attn_max_512_fwd_p.bind(qkv, return _self_fused_attn_fwd_p.bind(qkv,
bias, bias,
cu_seqlen, cu_seqlen,
seed, seed,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training) is_training=is_training)
class SelfFusedAttnMax512BwdPrimitive(BasePrimitive): class SelfFusedAttnBwdPrimitive(BasePrimitive):
""" """
Self Fused Attention Max Seqlen 512 Backward Primitive Self Fused Attention Backward Primitive
""" """
name = "te_self_fused_attn_max_512_backward" name = "te_self_fused_attn_backward"
multiple_results = True multiple_results = True
@staticmethod @staticmethod
def abstract( def abstract(
qkv, qkv,
softmax_aux, softmax_aux, # pylint: disable=unused-argument
rng_state, # pylint: disable=unused-argument
output, # pylint: disable=unused-argument
doutput, doutput,
cu_seqlen, # pylint: disable=unused-argument cu_seqlen, # pylint: disable=unused-argument
*, *,
...@@ -2139,11 +2162,14 @@ class SelfFusedAttnMax512BwdPrimitive(BasePrimitive): ...@@ -2139,11 +2162,14 @@ class SelfFusedAttnMax512BwdPrimitive(BasePrimitive):
Self fused attention bwd abstract Self fused attention bwd abstract
""" """
qkv_dtype = dtypes.canonicalize_dtype(qkv.dtype) qkv_dtype = dtypes.canonicalize_dtype(qkv.dtype)
assert qkv.dtype == softmax_aux.dtype == doutput.dtype assert qkv.dtype == doutput.dtype
_, seqlen, _, num_head, _ = qkv.shape _, seqlen, _, num_head, _ = qkv.shape
bias_shape = (1, num_head, seqlen, seqlen) if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_shape = (0,)
else:
bias_shape = (1, num_head, seqlen, seqlen)
bias_dtype = qkv_dtype bias_dtype = qkv_dtype
return ( return (
...@@ -2151,80 +2177,62 @@ class SelfFusedAttnMax512BwdPrimitive(BasePrimitive): ...@@ -2151,80 +2177,62 @@ class SelfFusedAttnMax512BwdPrimitive(BasePrimitive):
ShapedArray(bias_shape, bias_dtype, named_shape=qkv.named_shape)) ShapedArray(bias_shape, bias_dtype, named_shape=qkv.named_shape))
@staticmethod @staticmethod
def lowering(ctx, qkv, softmax_aux, doutput, cu_seqlen, *, attn_bias_type, attn_mask_type, def lowering(ctx, qkv, softmax_aux, rng_state, output, doutput, cu_seqlen, *, attn_bias_type,
scaling_factor, dropout_probability, is_training): attn_mask_type, scaling_factor, dropout_probability, is_training):
""" """
Self fused attention max seqlen 512 bwd lowering rules Self fused attention bwd lowering rules
""" """
qkv_aval, _, _, _ = ctx.avals_in qkv_aval, _, _, _, _, _ = ctx.avals_in
ir_qkv_type = ir.RankedTensorType(qkv.type)
ir_qkv_shape = ir_qkv_type.shape
ir_softmax_aux_type = ir.RankedTensorType(softmax_aux.type)
ir_softmax_aux_shape = ir_softmax_aux_type.shape
ir_doutput_type = ir.RankedTensorType(doutput.type)
ir_doutput_shape = ir_doutput_type.shape
ir_cu_seqlen_type = ir.RankedTensorType(cu_seqlen.type)
ir_cu_seqlen_shape = ir_cu_seqlen_type.shape
batch, max_seqlen, num_head, head_dim = ir_doutput_shape batch, max_seqlen, _, num_head, head_dim = qkv_aval.shape
dbias_shape = (1, num_head, max_seqlen, max_seqlen)
dbias_dtype = ir_qkv_type.element_type
operands = [qkv, softmax_aux, rng_state, output, doutput, cu_seqlen]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [ out_types = [
ir.RankedTensorType.get(ir_qkv_shape, ir_qkv_type.element_type), ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
ir.RankedTensorType.get(dbias_shape, dbias_dtype) for output in ctx.avals_out
] ]
operands = [qkv, softmax_aux, doutput, cu_seqlen]
operand_shapes = [ir_qkv_shape, ir_softmax_aux_shape, ir_doutput_shape, ir_cu_seqlen_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
# the dropout elements are encoded in the forward auxiliary tensor
# so seed is not needed in backward
opaque = transformer_engine_jax.pack_fused_attn_descriptor( opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability, batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability,
attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training) attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
out = custom_caller(SelfFusedAttnMax512BwdPrimitive.name, out = custom_caller(SelfFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
args,
opaque,
has_side_effect=False)
return out return out
_self_fused_attn_max_512_bwd_p = register_primitive(SelfFusedAttnMax512BwdPrimitive) _self_fused_attn_bwd_p = register_primitive(SelfFusedAttnBwdPrimitive)
def self_fused_attn_max_512_bwd(qkv: jnp.ndarray, softmax_aux: jnp.ndarray, doutput: jnp.ndarray, def self_fused_attn_bwd(qkv: jnp.ndarray, softmax_aux: jnp.ndarray, rng_state: jnp.ndarray,
cu_seqlen: jnp.ndarray, attn_bias_type: NVTE_Bias_Type, output: jnp.ndarray, doutput: jnp.ndarray, cu_seqlen: jnp.ndarray,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float, attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
dropout_probability: float, is_training: bool): scaling_factor: float, dropout_probability: float, is_training: bool):
""" """
Wrapper for TE self fused attention max seqlen 512 bwd Wrapper for TE self fused attention bwd
Return the gradients of self fused attention with packed qkv input Return the gradients of self fused attention with packed qkv input
""" """
return _self_fused_attn_max_512_bwd_p.bind(qkv, return _self_fused_attn_bwd_p.bind(qkv,
softmax_aux, softmax_aux,
doutput, rng_state,
cu_seqlen, output,
attn_bias_type=attn_bias_type, doutput,
attn_mask_type=attn_mask_type, cu_seqlen,
scaling_factor=scaling_factor, attn_bias_type=attn_bias_type,
dropout_probability=dropout_probability, attn_mask_type=attn_mask_type,
is_training=is_training) scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
class CrossFusedAttnMax512FwdPrimitive(BasePrimitive): class CrossFusedAttnFwdPrimitive(BasePrimitive):
""" """
Cross Fused Attention Forward Max Seqlen 512 Primitive Cross Fused Attention Forward Primitive
""" """
name = "te_cross_fused_attn_max_512_forward" name = "te_cross_fused_attn_forward"
multiple_results = True multiple_results = True
@staticmethod @staticmethod
...@@ -2242,7 +2250,7 @@ class CrossFusedAttnMax512FwdPrimitive(BasePrimitive): ...@@ -2242,7 +2250,7 @@ class CrossFusedAttnMax512FwdPrimitive(BasePrimitive):
is_training # pylint: disable=unused-argument is_training # pylint: disable=unused-argument
): ):
""" """
Cross fused attention max seqlen 512 fwd abstract Cross fused attention fwd abstract
""" """
q_dtype = dtypes.canonicalize_dtype(q.dtype) q_dtype = dtypes.canonicalize_dtype(q.dtype)
batch_q, q_max_seqlen, num_head_q, head_dim_q = q.shape batch_q, q_max_seqlen, num_head_q, head_dim_q = q.shape
...@@ -2271,36 +2279,19 @@ class CrossFusedAttnMax512FwdPrimitive(BasePrimitive): ...@@ -2271,36 +2279,19 @@ class CrossFusedAttnMax512FwdPrimitive(BasePrimitive):
def lowering(ctx, q, kv, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type, attn_mask_type, def lowering(ctx, q, kv, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training): scaling_factor, dropout_probability, is_training):
""" """
Cross fused attention max seqlen 512 fwd lowering rules Cross fused attention fwd lowering rules
""" """
q_aval, kv_aval, _, _, _ = ctx.avals_in q_aval, kv_aval, _, _, _ = ctx.avals_in
assert q_aval.dtype == kv_aval.dtype assert q_aval.dtype == kv_aval.dtype
ir_q_type = ir.RankedTensorType(q.type) batch, q_max_seqlen, num_head, head_dim = q_aval.shape
ir_q_shape = ir_q_type.shape kv_max_seqlen = kv_aval.shape[1]
ir_kv_type = ir.RankedTensorType(kv.type)
ir_kv_shape = ir_kv_type.shape
ir_q_cu_seqlen_shape = ir.RankedTensorType(q_cu_seqlen.type).shape
ir_kv_cu_seqlen_shape = ir.RankedTensorType(kv_cu_seqlen.type).shape
ir_seed_type = ir.RankedTensorType(seed.type)
ir_seed_shape = ir_seed_type.shape
batch, q_max_seqlen, num_head, head_dim = ir_q_shape
kv_max_seqlen = ir_kv_shape[1]
output_shape = (batch, q_max_seqlen, num_head, head_dim)
softmax_aux_shape = (batch, num_head, q_max_seqlen, kv_max_seqlen)
out_types = [
ir.RankedTensorType.get(output_shape, ir_q_type.element_type),
ir.RankedTensorType.get(softmax_aux_shape, ir_q_type.element_type)
]
operands = [q, kv, q_cu_seqlen, kv_cu_seqlen, seed] operands = [q, kv, q_cu_seqlen, kv_cu_seqlen, seed]
operand_shapes = [ operand_shapes = map(lambda x: x.type.shape, operands)
ir_q_shape, ir_kv_shape, ir_q_cu_seqlen_shape, ir_kv_cu_seqlen_shape, ir_seed_shape out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
] ]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
...@@ -2309,45 +2300,42 @@ class CrossFusedAttnMax512FwdPrimitive(BasePrimitive): ...@@ -2309,45 +2300,42 @@ class CrossFusedAttnMax512FwdPrimitive(BasePrimitive):
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training) jax_dtype_to_te_dtype(q_aval.dtype), is_training)
out = custom_caller(CrossFusedAttnMax512FwdPrimitive.name, out = custom_caller(CrossFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
args,
opaque,
has_side_effect=False)
return out return out
_cross_fused_attn_max_512_fwd_p = register_primitive(CrossFusedAttnMax512FwdPrimitive) _cross_fused_attn_fwd_p = register_primitive(CrossFusedAttnFwdPrimitive)
def cross_fused_attn_max_512_fwd(q: jnp.ndarray, kv: jnp.ndarray, q_cu_seqlen: jnp.ndarray, def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, q_cu_seqlen: jnp.ndarray,
kv_cu_seqlen: jnp.ndarray, seed: jnp.ndarray, kv_cu_seqlen: jnp.ndarray, seed: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type, attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
scaling_factor: float, dropout_probability: float, scaling_factor: float, dropout_probability: float, is_training: bool):
is_training: bool):
""" """
Wrapper for TE cross fused attention max seqlen 512 fwd Wrapper for TE cross fused attention fwd
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
""" """
seed = _check_seed(seed, dropout_probability, is_training) checker = _FusedAttnRNGStateChecker()
seed = checker.check_seed(seed, dropout_probability, is_training)
return _cross_fused_attn_max_512_fwd_p.bind(q, return _cross_fused_attn_fwd_p.bind(q,
kv, kv,
q_cu_seqlen, q_cu_seqlen,
kv_cu_seqlen, kv_cu_seqlen,
seed, seed,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training) is_training=is_training)
class CrossFusedAttnMax512BwdPrimitive(BasePrimitive): class CrossFusedAttnBwdPrimitive(BasePrimitive):
""" """
Cross Fused Attention Max Seqlen 512 Backward Primitive Cross Fused Attention Backward Primitive
""" """
name = "te_cross_fused_attn_max_512_backward" name = "te_cross_fused_attn_backward"
multiple_results = True multiple_results = True
@staticmethod @staticmethod
...@@ -2366,7 +2354,7 @@ class CrossFusedAttnMax512BwdPrimitive(BasePrimitive): ...@@ -2366,7 +2354,7 @@ class CrossFusedAttnMax512BwdPrimitive(BasePrimitive):
is_training # pylint: disable=unused-argument is_training # pylint: disable=unused-argument
): ):
""" """
Cross fused attention max seqlen 512 bwd abstract Cross fused attention bwd abstract
""" """
q_dtype = dtypes.canonicalize_dtype(q.dtype) q_dtype = dtypes.canonicalize_dtype(q.dtype)
kv_dtype = dtypes.canonicalize_dtype(kv.dtype) kv_dtype = dtypes.canonicalize_dtype(kv.dtype)
...@@ -2384,34 +2372,19 @@ class CrossFusedAttnMax512BwdPrimitive(BasePrimitive): ...@@ -2384,34 +2372,19 @@ class CrossFusedAttnMax512BwdPrimitive(BasePrimitive):
def lowering(ctx, q, kv, softmax_aux, doutput, q_cu_seqlen, kv_cu_seqlen, *, attn_bias_type, def lowering(ctx, q, kv, softmax_aux, doutput, q_cu_seqlen, kv_cu_seqlen, *, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training): attn_mask_type, scaling_factor, dropout_probability, is_training):
""" """
Cross fused attention max seqlen 512 bwd lowering rules Cross fused attention bwd lowering rules
""" """
q_aval, _, _, _, _, _ = ctx.avals_in q_aval, kv_aval, _, _, _, _ = ctx.avals_in
assert q_aval.dtype == kv_aval.dtype
ir_q_type = ir.RankedTensorType(q.type)
ir_q_shape = ir_q_type.shape
ir_kv_type = ir.RankedTensorType(kv.type)
ir_kv_shape = ir_kv_type.shape
ir_softmax_aux_type = ir.RankedTensorType(softmax_aux.type)
ir_softmax_aux_shape = ir_softmax_aux_type.shape
ir_doutput_shape = ir.RankedTensorType(doutput.type).shape
ir_q_cu_seqlen_shape = ir.RankedTensorType(q_cu_seqlen.type).shape
ir_kv_cu_seqlen_shape = ir.RankedTensorType(kv_cu_seqlen.type).shape
batch, q_max_seqlen, num_head, head_dim = ir_doutput_shape batch, q_max_seqlen, num_head, head_dim = q_aval.shape
kv_max_seqlen = ir_kv_shape[1] kv_max_seqlen = kv_aval.shape[1]
out_types = [
ir.RankedTensorType.get(ir_q_shape, ir_q_type.element_type),
ir.RankedTensorType.get(ir_kv_shape, ir_kv_type.element_type),
]
operands = [q, kv, softmax_aux, doutput, q_cu_seqlen, kv_cu_seqlen] operands = [q, kv, softmax_aux, doutput, q_cu_seqlen, kv_cu_seqlen]
operand_shapes = [ operand_shapes = map(lambda x: x.type.shape, operands)
ir_q_shape, ir_kv_shape, ir_softmax_aux_shape, ir_doutput_shape, ir_q_cu_seqlen_shape, out_types = [
ir_kv_cu_seqlen_shape ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
] ]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
...@@ -2423,34 +2396,30 @@ class CrossFusedAttnMax512BwdPrimitive(BasePrimitive): ...@@ -2423,34 +2396,30 @@ class CrossFusedAttnMax512BwdPrimitive(BasePrimitive):
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training) jax_dtype_to_te_dtype(q_aval.dtype), is_training)
out = custom_caller(CrossFusedAttnMax512BwdPrimitive.name, out = custom_caller(CrossFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
args,
opaque,
has_side_effect=False)
return out return out
_cross_fused_attn_max_512_bwd_p = register_primitive(CrossFusedAttnMax512BwdPrimitive) _cross_fused_attn_bwd_p = register_primitive(CrossFusedAttnBwdPrimitive)
def cross_fused_attn_max_512_bwd(q: jnp.ndarray, kv: jnp.ndarray, softmax_aux: jnp.ndarray, def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, softmax_aux: jnp.ndarray,
doutput: jnp.ndarray, q_cu_seqlen: jnp.ndarray, doutput: jnp.ndarray, q_cu_seqlen: jnp.ndarray, kv_cu_seqlen: jnp.ndarray,
kv_cu_seqlen: jnp.ndarray, attn_bias_type: NVTE_Bias_Type, attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float, scaling_factor: float, dropout_probability: float, is_training: bool):
dropout_probability: float, is_training: bool):
""" """
Wrapper for TE cross fused attention max seqlen 512 bwd Wrapper for TE cross fused attention bwd
Return the gradients of cross fused attention with packed kv input Return the gradients of cross fused attention with packed kv input
""" """
return _cross_fused_attn_max_512_bwd_p.bind(q, return _cross_fused_attn_bwd_p.bind(q,
kv, kv,
softmax_aux, softmax_aux,
doutput, doutput,
q_cu_seqlen, q_cu_seqlen,
kv_cu_seqlen, kv_cu_seqlen,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training) is_training=is_training)
...@@ -46,11 +46,10 @@ pybind11::dict Registrations() { ...@@ -46,11 +46,10 @@ pybind11::dict Registrations() {
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForward); EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForward);
dict["te_scaled_upper_triang_masked_softmax_backward"] = dict["te_scaled_upper_triang_masked_softmax_backward"] =
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward); EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward);
dict["te_self_fused_attn_max_512_forward"] = EncapsulateFunction(SelfFusedAttnMax512Forward); dict["te_self_fused_attn_forward"] = EncapsulateFunction(SelfFusedAttnForward);
dict["te_self_fused_attn_max_512_backward"] = EncapsulateFunction(SelfFusedAttnMax512Backward); dict["te_self_fused_attn_backward"] = EncapsulateFunction(SelfFusedAttnBackward);
dict["te_cross_fused_attn_max_512_forward"] = EncapsulateFunction(CrossFusedAttnMax512Forward); dict["te_cross_fused_attn_forward"] = EncapsulateFunction(CrossFusedAttnForward);
dict["te_cross_fused_attn_max_512_backward"] = dict["te_cross_fused_attn_backward"] = EncapsulateFunction(CrossFusedAttnBackward);
EncapsulateFunction(CrossFusedAttnMax512Backward);
return dict; return dict;
} }
...@@ -65,6 +64,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { ...@@ -65,6 +64,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_device_compute_capability", &GetDeviceComputeCapability); m.def("get_device_compute_capability", &GetDeviceComputeCapability);
m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor); m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor);
m.def("is_fused_attn_kernel_available", &IsFusedAttnKernelAvailable); m.def("is_fused_attn_kernel_available", &IsFusedAttnKernelAvailable);
m.def("get_fused_attn_backend", &GetFusedAttnBackend);
pybind11::enum_<DType>(m, "DType", pybind11::module_local()) pybind11::enum_<DType>(m, "DType", pybind11::module_local())
.value("kByte", DType::kByte) .value("kByte", DType::kByte)
...@@ -85,6 +85,17 @@ PYBIND11_MODULE(transformer_engine_jax, m) { ...@@ -85,6 +85,17 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
.value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
.value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK); .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK);
pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local())
.value("NVTE_NOT_INTERLEAVED", NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED)
.value("NVTE_QKV_INTERLEAVED", NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
.value("NVTE_KV_INTERLEAVED", NVTE_QKV_Layout::NVTE_KV_INTERLEAVED);
pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local())
.value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend)
.value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)
.value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)
.value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8);
} }
} // namespace jax } // namespace jax
......
...@@ -740,8 +740,19 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, ...@@ -740,8 +740,19 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers,
desc.scale_factor, stream); desc.scale_factor, stream);
} }
void SelfFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char *opaque, NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
size_t opaque_len) { NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim) {
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_max_seqlen, kv_max_seqlen, head_dim);
return backend;
}
void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor = const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len); *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
...@@ -754,12 +765,17 @@ void SelfFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char ...@@ -754,12 +765,17 @@ void SelfFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
// output // output
void *output = buffers[4]; void *output = buffers[4];
void *softmax_aux = buffers[5]; void *softmax_aux = buffers[5];
void *rng_state = buffers[6];
auto batch = descriptor.batch; auto batch = descriptor.batch;
auto num_head = descriptor.num_head; auto num_head = descriptor.num_head;
auto q_max_seqlen = descriptor.q_max_seqlen; auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen; auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim; auto head_dim = descriptor.head_dim;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED;
NVTE_CHECK(q_max_seqlen == kv_max_seqlen, NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
"q_max_seqlen should be equal to kv_max_seqlen in the self attention."); "q_max_seqlen should be equal to kv_max_seqlen in the self attention.");
...@@ -768,78 +784,81 @@ void SelfFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char ...@@ -768,78 +784,81 @@ void SelfFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim}; auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen}; auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
// input tensors
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype); auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
auto cu_seqlens_tensor =
TensorWrapper(cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
// FP16/BF16 doesn't use this tensor // output tensors
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto o_tensor = auto o_tensor =
TensorWrapper(output, std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}, dtype); TensorWrapper(output, std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}, dtype);
auto cu_seqlens_tensor = // F16 doesn't use this tensor
TensorWrapper(cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32); auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64); // aux tensors
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_max_seqlen, kv_max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
NVTETensorPack aux_output_tensors; NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors); nvte_tensor_pack_create(&aux_output_tensors);
TensorWrapper query_workspace_tensor; TensorWrapper query_workspace_tensor;
nvte_fused_attn_fwd_qkvpacked( nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
&aux_output_tensors, cu_seqlens_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, rng_state_tensor.data(), q_max_seqlen, descriptor.is_training,
descriptor.is_training, descriptor.scaling_factor, descriptor.dropout_probability, descriptor.scaling_factor, dropout_probability, qkv_layout,
NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type, bias_type, mask_type, query_workspace_tensor.data(), stream);
query_workspace_tensor.data(), stream);
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]); auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
output_s->data.dptr = softmax_aux; output_s->data.dptr = softmax_aux;
// fused attn workspace + workspace for rng_state auto workspace_size = query_workspace_tensor.shape().data[0];
auto plan_workspace_size = auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size);
query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
auto rng_workspace_size = 2 * sizeof(int64_t);
auto total_workspace_size = plan_workspace_size + rng_workspace_size;
auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(total_workspace_size);
auto workspace_tensor = auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype()); TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
auto rng_state = static_cast<uint8_t *>(workspace) + plan_workspace_size;
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, stream);
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(), o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, descriptor.is_training, rng_state_tensor.data(), q_max_seqlen, descriptor.is_training,
descriptor.scaling_factor, descriptor.dropout_probability, descriptor.scaling_factor, dropout_probability, qkv_layout,
NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type, bias_type, mask_type, workspace_tensor.data(), stream);
descriptor.mask_type, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors); nvte_tensor_pack_destroy(&aux_output_tensors);
} }
void SelfFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char *opaque, void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) { size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor = const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len); *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input // input
void *qkv = buffers[0]; void *qkv = buffers[0];
void *softmax_aux = buffers[1]; void *softmax_aux = buffers[1];
void *doutput = buffers[2]; void *rng_state = buffers[2];
void *cu_seqlens = buffers[3]; void *output = buffers[3];
void *doutput = buffers[4];
void *cu_seqlens = buffers[5];
// output // output
void *dqkv = buffers[4]; void *dqkv = buffers[6];
void *dp = softmax_aux; void *dbias = buffers[7];
void *dbias = buffers[5];
auto batch = descriptor.batch; auto batch = descriptor.batch;
auto num_head = descriptor.num_head; auto num_head = descriptor.num_head;
auto q_max_seqlen = descriptor.q_max_seqlen; auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen; auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim; auto head_dim = descriptor.head_dim;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED;
NVTE_CHECK(q_max_seqlen == kv_max_seqlen, NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
"q_max_seqlen should be equal to kv_max_seqlen in the self attention."); "q_max_seqlen should be equal to kv_max_seqlen in the self attention.");
...@@ -850,11 +869,9 @@ void SelfFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char ...@@ -850,11 +869,9 @@ void SelfFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen}; auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
auto output_tensor = TensorWrapper(output, output_shape, dtype);
auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype); auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
// It's a little trick that the flash attn needs fwd output // F16 doesn't use this tensor
// But when seqlen <= 512, it is not needed
auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
// FP16/BF16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype); auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
...@@ -862,50 +879,47 @@ void SelfFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char ...@@ -862,50 +879,47 @@ void SelfFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char
auto cu_seqlens_tensor = auto cu_seqlens_tensor =
TensorWrapper(cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32); TensorWrapper(cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
// Currently, no rng_state required for bwd
auto rng_state = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt64);
// TODO: needs to think about how to pass aux_output_tensors // TODO: needs to think about how to pass aux_output_tensors
NVTETensorPack aux_output_tensors; NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors); nvte_tensor_pack_create(&aux_output_tensors);
aux_output_tensors.size = 1; aux_output_tensors.size = 2;
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]); auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
output_s->data.shape = std::vector<size_t>{batch, num_head, q_max_seqlen, kv_max_seqlen};
output_s->data.dptr = softmax_aux; output_s->data.dptr = softmax_aux;
auto *rng_state_tensor = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[1]);
rng_state_tensor->data.shape = std::vector<size_t>{2};
rng_state_tensor->data.dtype = DType::kInt64;
rng_state_tensor->data.dptr = rng_state;
TensorWrapper query_workspace_tensor; TensorWrapper query_workspace_tensor;
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for FP16/BF16 s_tensor.data(), // not used for F16
s_tensor.data(), // not used for FP16/BF16 s_tensor.data(), // not used for F16
&aux_output_tensors, dqkv_tensor.data(), dbias_tensor.data(), &aux_output_tensors, dqkv_tensor.data(), dbias_tensor.data(),
cu_seqlens_tensor.data(), q_max_seqlen, descriptor.scaling_factor, cu_seqlens_tensor.data(), q_max_seqlen, descriptor.scaling_factor,
descriptor.dropout_probability, dropout_probability, qkv_layout, bias_type, mask_type,
NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type, query_workspace_tensor.data(), stream);
descriptor.mask_type, query_workspace_tensor.data(), stream);
size_t workspace_size = size_t workspace_size = query_workspace_tensor.shape().data[0];
query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size); auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size);
auto workspace_tensor = auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype()); TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for FP16/BF16 s_tensor.data(), // not used for F16
s_tensor.data(), // not used for FP16/BF16 s_tensor.data(), // not used for F16
&aux_output_tensors, dqkv_tensor.data(), dbias_tensor.data(), &aux_output_tensors, dqkv_tensor.data(), dbias_tensor.data(),
cu_seqlens_tensor.data(), q_max_seqlen, descriptor.scaling_factor, cu_seqlens_tensor.data(), q_max_seqlen, descriptor.scaling_factor,
descriptor.dropout_probability, dropout_probability, qkv_layout, bias_type, mask_type,
NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type, workspace_tensor.data(), stream);
descriptor.mask_type, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors); nvte_tensor_pack_destroy(&aux_output_tensors);
} }
void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char *opaque, void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) { size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor = const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len); *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
...@@ -925,6 +939,10 @@ void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char ...@@ -925,6 +939,10 @@ void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
auto q_max_seqlen = descriptor.q_max_seqlen; auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen; auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim; auto head_dim = descriptor.head_dim;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_KV_INTERLEAVED;
auto dtype = descriptor.dtype; auto dtype = descriptor.dtype;
auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}; auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
...@@ -958,8 +976,7 @@ void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char ...@@ -958,8 +976,7 @@ void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training, dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
descriptor.scaling_factor, descriptor.dropout_probability, descriptor.scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
NVTE_QKV_Layout::NVTE_KV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type,
query_workspace_tensor.data(), stream); query_workspace_tensor.data(), stream);
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]); auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
...@@ -976,21 +993,24 @@ void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char ...@@ -976,21 +993,24 @@ void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
auto rng_state = static_cast<uint8_t *>(workspace) + plan_workspace_size; auto rng_state = static_cast<uint8_t *>(workspace) + plan_workspace_size;
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64); auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, stream);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_max_seqlen, kv_max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
nvte_fused_attn_fwd_kvpacked( nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training, rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
descriptor.scaling_factor, descriptor.dropout_probability, descriptor.scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
NVTE_QKV_Layout::NVTE_KV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type,
workspace_tensor.data(), stream); workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors); nvte_tensor_pack_destroy(&aux_output_tensors);
} }
void CrossFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char *opaque, void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) { size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor = const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len); *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
......
...@@ -116,6 +116,12 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor( ...@@ -116,6 +116,12 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor(
bool IsFusedAttnKernelAvailable(); bool IsFusedAttnKernelAvailable();
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim);
void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
...@@ -166,17 +172,17 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, ...@@ -166,17 +172,17 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers,
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque, void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len); std::size_t opaque_len);
void SelfFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char *opaque, void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
void SelfFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char *opaque, void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char *opaque, void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
void CrossFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char *opaque, void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
......
...@@ -41,9 +41,15 @@ __global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t ...@@ -41,9 +41,15 @@ __global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t
} }
void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen, void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
size_t kv_max_seqlen, cudaStream_t stream) { size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
constexpr int threads_per_cta = 128; cudaStream_t stream) {
const size_t increment = (q_max_seqlen * kv_max_seqlen + threads_per_cta - 1) / threads_per_cta; size_t increment = 0;
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
increment = 16;
} else {
constexpr int threads_per_cta = 128;
increment = (q_max_seqlen * kv_max_seqlen + threads_per_cta - 1) / threads_per_cta;
}
auto offset = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(increment); auto offset = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(increment);
populate_rng_state_kernel<<<1, 1, 0, stream>>>(reinterpret_cast<int64_t *>(rng_state_dst), populate_rng_state_kernel<<<1, 1, 0, stream>>>(reinterpret_cast<int64_t *>(rng_state_dst),
reinterpret_cast<const int64_t *>(seed), offset); reinterpret_cast<const int64_t *>(seed), offset);
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include "transformer_engine/fused_attn.h"
#include "transformer_engine/logging.h" #include "transformer_engine/logging.h"
namespace transformer_engine { namespace transformer_engine {
...@@ -22,7 +23,8 @@ int GetCudaRuntimeVersion(); ...@@ -22,7 +23,8 @@ int GetCudaRuntimeVersion();
int GetDeviceComputeCapability(int gpu_id); int GetDeviceComputeCapability(int gpu_id);
void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen, void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
size_t kv_max_seqlen, cudaStream_t stream); size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
cudaStream_t stream);
class cublasLtMetaManager { class cublasLtMetaManager {
public: public:
......
...@@ -178,7 +178,8 @@ def core_attention(query: Array, ...@@ -178,7 +178,8 @@ def core_attention(query: Array,
attn_weights = Softmax(softmax_type=softmax_type, attn_weights = Softmax(softmax_type=softmax_type,
scale_factor=fused_scale_factor, scale_factor=fused_scale_factor,
sharding_type=softmax_sharding_type)(attn_weights, mask, bias) sharding_type=softmax_sharding_type)(attn_weights, mask,
bias).astype(dtype)
if not deterministic and dropout_rate > 0.: if not deterministic and dropout_rate > 0.:
keep_prob = 1.0 - dropout_rate keep_prob = 1.0 - dropout_rate
...@@ -369,12 +370,20 @@ class MultiHeadAttention(nn.Module): ...@@ -369,12 +370,20 @@ class MultiHeadAttention(nn.Module):
canonicalize_dtype = dtypes.canonicalize_dtype(self.dtype) canonicalize_dtype = dtypes.canonicalize_dtype(self.dtype)
q_seqlen = inputs_q.shape[0] if self.transpose_batch_sequence else inputs_q.shape[1] q_seqlen = inputs_q.shape[0] if self.transpose_batch_sequence else inputs_q.shape[1]
kv_seqlen = inputs_kv.shape[0] if self.transpose_batch_sequence else inputs_kv.shape[1] kv_seqlen = inputs_kv.shape[0] if self.transpose_batch_sequence else inputs_kv.shape[1]
fused_attn_supported_seqlen = [128, 256, 384, 512]
enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0")) enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))
def _check_seqlen(seqlen):
return seqlen % 64 == 0
def _check_head_dim(head_dim):
return head_dim in [64, 128]
use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \ use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \
canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \ canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \
q_seqlen in fused_attn_supported_seqlen and kv_seqlen in fused_attn_supported_seqlen \ _check_seqlen(q_seqlen) and _check_seqlen(kv_seqlen) and \
and is_fused_attn_kernel_available() and (self.head_dim == 64) and enable_fused_attn _check_head_dim(self.head_dim) and \
is_fused_attn_kernel_available() and \
enable_fused_attn
if enable_fused_attn and not use_fused_attn: if enable_fused_attn and not use_fused_attn:
reason = "" reason = ""
...@@ -388,16 +397,16 @@ class MultiHeadAttention(nn.Module): ...@@ -388,16 +397,16 @@ class MultiHeadAttention(nn.Module):
if canonicalize_dtype not in [jnp.bfloat16, jnp.float16]: if canonicalize_dtype not in [jnp.bfloat16, jnp.float16]:
reason += f"dtype in [BF16, FP16] is required " \ reason += f"dtype in [BF16, FP16] is required " \
f"but got dtype={canonicalize_dtype}, " f"but got dtype={canonicalize_dtype}, "
if q_seqlen not in fused_attn_supported_seqlen: if not _check_seqlen(q_seqlen):
reason += f"q_seqlen in {fused_attn_supported_seqlen} is required " \ reason += f"q_seqlen % 64 == 0 is required " \
f"but got {q_seqlen=}, " f"but got {q_seqlen=}, "
if kv_seqlen not in fused_attn_supported_seqlen: if not _check_seqlen(kv_seqlen):
reason += f"kv_seqlen in {fused_attn_supported_seqlen} is required " \ reason += f"kv_seqlen % 64 == 0 is required " \
f"but got {kv_seqlen=}, " f"but got {kv_seqlen=}, "
if not _check_head_dim(self.head_dim):
reason += f"head_dim should be 64 or 128 but got {self.head_dim}, "
if not is_fused_attn_kernel_available(): if not is_fused_attn_kernel_available():
reason += "GPU arch >= Ampere and cuDNN >= 8.9.1 are required, " reason += "GPU arch >= Ampere and cuDNN >= 8.9.1 are required, "
if self.head_dim != 64:
reason += f"head_dim should be 64 but got {self.head_dim}, "
warnings.warn( warnings.warn(
f"Fused attention is not enabled, " \ f"Fused attention is not enabled, " \
......
...@@ -12,8 +12,8 @@ import transformer_engine_jax ...@@ -12,8 +12,8 @@ import transformer_engine_jax
from transformer_engine_jax import NVTE_Bias_Type from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type from transformer_engine_jax import NVTE_Mask_Type
from .cpp_extensions import cross_fused_attn_max_512_fwd, cross_fused_attn_max_512_bwd from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd
from .cpp_extensions import self_fused_attn_max_512_fwd, self_fused_attn_max_512_bwd from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd
from .sharding import get_fused_attn_sharding_meta from .sharding import get_fused_attn_sharding_meta
from .sharding import ShardingType from .sharding import ShardingType
from .sharding import xmap_runner from .sharding import xmap_runner
...@@ -57,18 +57,18 @@ def self_fused_attn(qkv: jnp.ndarray, ...@@ -57,18 +57,18 @@ def self_fused_attn(qkv: jnp.ndarray,
Self fused attention wrapper Self fused attention wrapper
""" """
assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \ assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
"Fused_attn_max_512 does not support row-split tensor parallelism currently." "self_fused_attn does not support row-split tensor parallelism currently."
if sharding_type is ShardingType.SINGLE: if sharding_type is ShardingType.SINGLE:
output = _self_fused_attn_max_512(qkv, output = _self_fused_attn(qkv,
bias, bias,
mask, mask,
seed, seed,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training) is_training=is_training)
else: else:
dp_axis_name = "batch" dp_axis_name = "batch"
tp_axis_name = "model" tp_axis_name = "model"
...@@ -87,14 +87,14 @@ def self_fused_attn(qkv: jnp.ndarray, ...@@ -87,14 +87,14 @@ def self_fused_attn(qkv: jnp.ndarray,
jnp.reshape(x, new_shape) if x is not None else None jnp.reshape(x, new_shape) if x is not None else None
for x, new_shape in zip(inputs, sharding_meta.input_shapes)) for x, new_shape in zip(inputs, sharding_meta.input_shapes))
partial_self_fused_attn_max_512 = partial(_self_fused_attn_max_512, partial_self_fused_attn = partial(_self_fused_attn,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training) is_training=is_training)
output_ = xmap_runner(partial_self_fused_attn_max_512, sharding_meta.in_axes, output_ = xmap_runner(partial_self_fused_attn, sharding_meta.in_axes,
sharding_meta.out_axes[0], sharding_meta.axis_resources, inputs_) sharding_meta.out_axes[0], sharding_meta.axis_resources, inputs_)
output = jnp.reshape(output_, sharding_meta.output_shapes[0]) output = jnp.reshape(output_, sharding_meta.output_shapes[0])
...@@ -103,61 +103,65 @@ def self_fused_attn(qkv: jnp.ndarray, ...@@ -103,61 +103,65 @@ def self_fused_attn(qkv: jnp.ndarray,
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8)) @partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _self_fused_attn_max_512(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, def _self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
attn_mask_type: AttnMaskType, scaling_factor: float, scaling_factor: float, dropout_probability: float, is_training: bool):
dropout_probability: float, is_training: bool): output, _ = _self_fused_attn_fwd(qkv,
output, _ = _self_fused_attn_max_512_fwd(qkv, bias,
bias, mask,
mask, seed,
seed, attn_bias_type=attn_bias_type,
attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type,
attn_mask_type=attn_mask_type, scaling_factor=scaling_factor,
scaling_factor=scaling_factor, dropout_probability=dropout_probability,
dropout_probability=dropout_probability, is_training=is_training)
is_training=is_training)
return output return output
def _self_fused_attn_max_512_fwd(qkv, bias, mask, seed, attn_bias_type, attn_mask_type, def _self_fused_attn_fwd(qkv, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
scaling_factor, dropout_probability, is_training): dropout_probability, is_training):
seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32) seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
cu_seqlen = jnp.cumsum(seqlen) cu_seqlen = jnp.cumsum(seqlen)
cu_seqlen = jnp.hstack((0, cu_seqlen)) cu_seqlen = jnp.hstack((0, cu_seqlen))
output, softmax_aux = self_fused_attn_max_512_fwd(qkv, output, softmax_aux, rng_state = self_fused_attn_fwd(qkv,
bias, bias,
cu_seqlen, cu_seqlen,
seed, seed,
attn_bias_type=attn_bias_type.value, attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value, attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training) is_training=is_training)
return output, (qkv, softmax_aux, cu_seqlen) return output, (qkv, softmax_aux, rng_state, output, cu_seqlen)
def _self_fused_attn_max_512_bwd(attn_bias_type, attn_mask_type, scaling_factor, def _self_fused_attn_bwd(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
dropout_probability, is_training, ctx, grad): is_training, ctx, grad):
qkv, softmax_aux, cu_seqlen = ctx qkv, softmax_aux, rng_state, output, cu_seqlen = ctx
doutput = grad doutput = grad
grad_qkv, grad_bias = self_fused_attn_max_512_bwd(qkv, grad_qkv, grad_bias = self_fused_attn_bwd(qkv,
softmax_aux, softmax_aux,
doutput, rng_state,
cu_seqlen, output,
attn_bias_type=attn_bias_type.value, doutput,
attn_mask_type=attn_mask_type.value, cu_seqlen,
scaling_factor=scaling_factor, attn_bias_type=attn_bias_type.value,
dropout_probability=dropout_probability, attn_mask_type=attn_mask_type.value,
is_training=is_training) scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
grad_bias = None
return grad_qkv, grad_bias, None, None return grad_qkv, grad_bias, None, None
_self_fused_attn_max_512.defvjp(_self_fused_attn_max_512_fwd, _self_fused_attn_max_512_bwd) _self_fused_attn.defvjp(_self_fused_attn_fwd, _self_fused_attn_bwd)
def cross_fused_attn(q: jnp.ndarray, def cross_fused_attn(q: jnp.ndarray,
...@@ -174,18 +178,18 @@ def cross_fused_attn(q: jnp.ndarray, ...@@ -174,18 +178,18 @@ def cross_fused_attn(q: jnp.ndarray,
Cross multi-head attention wrapper Cross multi-head attention wrapper
""" """
assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \ assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
"Fused_attn_max_512 does not support row-split tensor parallelism currently." "cross_fused_attn does not support row-split tensor parallelism currently."
if sharding_type is ShardingType.SINGLE: if sharding_type is ShardingType.SINGLE:
output = _cross_fused_attn_max_512(q, output = _cross_fused_attn(q,
kv, kv,
mask, mask,
seed, seed,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training) is_training=is_training)
else: else:
dp_axis_name = "batch" dp_axis_name = "batch"
tp_axis_name = "model" tp_axis_name = "model"
...@@ -203,14 +207,14 @@ def cross_fused_attn(q: jnp.ndarray, ...@@ -203,14 +207,14 @@ def cross_fused_attn(q: jnp.ndarray,
jnp.reshape(x, new_shape) if x is not None else None jnp.reshape(x, new_shape) if x is not None else None
for x, new_shape in zip(inputs, sharding_meta.input_shapes)) for x, new_shape in zip(inputs, sharding_meta.input_shapes))
partial_cross_fused_attn_max_512 = partial(_cross_fused_attn_max_512, partial_cross_fused_attn = partial(_cross_fused_attn,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training) is_training=is_training)
output_ = xmap_runner(partial_cross_fused_attn_max_512, sharding_meta.in_axes, output_ = xmap_runner(partial_cross_fused_attn, sharding_meta.in_axes,
sharding_meta.out_axes[0], sharding_meta.axis_resources, inputs_) sharding_meta.out_axes[0], sharding_meta.axis_resources, inputs_)
output = jnp.reshape(output_, sharding_meta.output_shapes[0]) output = jnp.reshape(output_, sharding_meta.output_shapes[0])
...@@ -219,24 +223,24 @@ def cross_fused_attn(q: jnp.ndarray, ...@@ -219,24 +223,24 @@ def cross_fused_attn(q: jnp.ndarray,
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
                      attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                      scaling_factor: float, dropout_probability: float, is_training: bool):
    """Cross fused attention with a custom VJP.

    The primal computation simply delegates to the forward rule; the
    residuals it returns are only needed when differentiating, so they
    are discarded here.
    """
    output, _residuals = _cross_fused_attn_fwd(q,
                                               kv,
                                               mask,
                                               seed,
                                               attn_bias_type=attn_bias_type,
                                               attn_mask_type=attn_mask_type,
                                               scaling_factor=scaling_factor,
                                               dropout_probability=dropout_probability,
                                               is_training=is_training)
    return output
def _cross_fused_attn_max_512_fwd(q, kv, mask, seed, attn_bias_type, attn_mask_type, scaling_factor, def _cross_fused_attn_fwd(q, kv, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training): dropout_probability, is_training):
q_seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32) q_seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
q_cu_seqlen = jnp.cumsum(q_seqlen) q_cu_seqlen = jnp.cumsum(q_seqlen)
...@@ -246,38 +250,38 @@ def _cross_fused_attn_max_512_fwd(q, kv, mask, seed, attn_bias_type, attn_mask_t ...@@ -246,38 +250,38 @@ def _cross_fused_attn_max_512_fwd(q, kv, mask, seed, attn_bias_type, attn_mask_t
kv_cu_seqlen = jnp.cumsum(kv_seqlen) kv_cu_seqlen = jnp.cumsum(kv_seqlen)
kv_cu_seqlen = jnp.hstack((0, kv_cu_seqlen)) kv_cu_seqlen = jnp.hstack((0, kv_cu_seqlen))
output, softmax_aux = cross_fused_attn_max_512_fwd(q, output, softmax_aux = cross_fused_attn_fwd(q,
kv, kv,
q_cu_seqlen, q_cu_seqlen,
kv_cu_seqlen, kv_cu_seqlen,
seed, seed,
attn_bias_type=attn_bias_type.value, attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value, attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training) is_training=is_training)
return output, (softmax_aux, q, kv, q_cu_seqlen, kv_cu_seqlen) return output, (softmax_aux, q, kv, q_cu_seqlen, kv_cu_seqlen)
def _cross_fused_attn_bwd(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
                          is_training, ctx, grad):
    """Backward rule of the cross fused attention custom VJP.

    Recovers the residuals from the forward rule, invokes the fused
    backward primitive, and emits one gradient per differentiable
    argument of `_cross_fused_attn` (q, kv, mask, seed).
    """
    # Residuals packed by the forward rule.
    softmax_aux, q, kv, q_cu_seqlen, kv_cu_seqlen = ctx
    grad_q, grad_kv = cross_fused_attn_bwd(
        q,
        kv,
        softmax_aux,
        grad,
        q_cu_seqlen,
        kv_cu_seqlen,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training)
    # The mask and the RNG seed are not differentiable.
    return grad_q, grad_kv, None, None
_cross_fused_attn_max_512.defvjp(_cross_fused_attn_max_512_fwd, _cross_fused_attn_max_512_bwd) _cross_fused_attn.defvjp(_cross_fused_attn_fwd, _cross_fused_attn_bwd)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment