Unverified Commit bfaec644 authored by zlsh80826's avatar zlsh80826 Committed by GitHub
Browse files

[C/JAX] Support more mask types for the arbitrary seqlen kernels and minor...


[C/JAX] Support more mask types for the arbitrary seqlen kernels and minor changes of JAX bias (#469)

* Move bias to float32
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Enable varlen
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Increase neg infinity abs values
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Enable varlen tests
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Remove unnecessary code
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Fix lint
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Support variable sequence length after cuDNN 8.9.6
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Use unique_ptr instead of shared_ptr
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Add a new mask type: PADDING_CAUSAL_MASK
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Support flash padding mask after 8.9.6
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Enhance the Max512 handling for causal masking and add the related tests
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Update the fused attn support lists
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Remove padding_aware from the caching
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Fix libtransformer.so issue
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Reduce the pad ratio tests
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Fix a bug with cuDNN 8.9.5
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Release backend resource after the module level unit test
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Clean the jax live arrays before running the unit tests
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Fix too-few-public-methods lint
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

---------
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>
parent 64a3d1d5
...@@ -37,6 +37,16 @@ DTYPES = [jnp.bfloat16, jnp.float32] ...@@ -37,6 +37,16 @@ DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_fp8_available()
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
"""
yield
for arr in jax.live_arrays():
arr.delete()
class TestFP8Dot: class TestFP8Dot:
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# See LICENSE for license information. # See LICENSE for license information.
import pytest import pytest
import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax.core import ShapedArray from jax.core import ShapedArray
...@@ -31,6 +32,16 @@ DTYPE = [DType.kFloat32, DType.kFloat16, DType.kBFloat16] ...@@ -31,6 +32,16 @@ DTYPE = [DType.kFloat32, DType.kFloat16, DType.kBFloat16]
TRANSPOSE = [True, False] TRANSPOSE = [True, False]
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
"""
yield
for arr in jax.live_arrays():
arr.delete()
class TestGEMMShapeInfer: class TestGEMMShapeInfer:
@staticmethod @staticmethod
......
...@@ -21,12 +21,22 @@ from jax import value_and_grad, jit ...@@ -21,12 +21,22 @@ from jax import value_and_grad, jit
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn
from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
from transformer_engine_jax import get_device_compute_capability from transformer_engine_jax import get_device_compute_capability # pylint: disable=wrong-import-order
# Type annotations # Type annotations
Array = jnp.ndarray Array = jnp.ndarray
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
"""
yield
for arr in jax.live_arrays():
arr.delete()
class Backend(Enum): class Backend(Enum):
""" """
Fused attn backend. Fused attn backend.
...@@ -52,6 +62,13 @@ CROSS_CASES = [(32, 128, 512, 16, 64)] ...@@ -52,6 +62,13 @@ CROSS_CASES = [(32, 128, 512, 16, 64)]
DTYPES = [jnp.bfloat16, jnp.float16] DTYPES = [jnp.bfloat16, jnp.float16]
def is_causal_mask(mask: AttnMaskType):
"""
Check if the mask is a causal mask
"""
return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]
def make_decoder_mask(tokens: Array) -> Array: def make_decoder_mask(tokens: Array) -> Array:
""" """
Create padded causal mask Create padded causal mask
...@@ -66,7 +83,7 @@ def jax_self_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs): ...@@ -66,7 +83,7 @@ def jax_self_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
Self attention with JAX native implementation Self attention with JAX native implementation
""" """
attn_mask_type = kwargs['attn_mask_type'] attn_mask_type = kwargs['attn_mask_type']
if attn_mask_type == AttnMaskType.CAUSAL_MASK: if is_causal_mask(attn_mask_type):
mask = make_decoder_mask(q_token) mask = make_decoder_mask(q_token)
else: else:
mask = make_attention_mask(q_token > 0, kv_token > 0) mask = make_attention_mask(q_token > 0, kv_token > 0)
...@@ -84,8 +101,8 @@ def jax_self_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs): ...@@ -84,8 +101,8 @@ def jax_self_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
deterministic=not kwargs['is_training'], deterministic=not kwargs['is_training'],
dropout_rate=kwargs['dropout_probability'], dropout_rate=kwargs['dropout_probability'],
dropout_rng=dropout_rng, dropout_rng=dropout_rng,
dtype=qkv.dtype) dtype=jnp.float32)
return output return output.astype(qkv.dtype)
def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs): def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
...@@ -95,7 +112,7 @@ def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs): ...@@ -95,7 +112,7 @@ def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
assert q.dtype == kv.dtype assert q.dtype == kv.dtype
attn_mask_type = kwargs['attn_mask_type'] attn_mask_type = kwargs['attn_mask_type']
if attn_mask_type == AttnMaskType.CAUSAL_MASK: if is_causal_mask(attn_mask_type):
raise NotImplementedError raise NotImplementedError
mask = make_attention_mask(q_token > 0, kv_token > 0) mask = make_attention_mask(q_token > 0, kv_token > 0)
...@@ -112,15 +129,16 @@ def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs): ...@@ -112,15 +129,16 @@ def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
deterministic=not kwargs['is_training'], deterministic=not kwargs['is_training'],
dropout_rate=kwargs['dropout_probability'], dropout_rate=kwargs['dropout_probability'],
dropout_rng=dropout_rng, dropout_rng=dropout_rng,
dtype=q.dtype) dtype=jnp.float32)
return output return output.astype(q.dtype)
def customcall_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs): def customcall_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
""" """
Self fused attention Self fused attention
""" """
if kwargs['attn_mask_type'] == AttnMaskType.CAUSAL_MASK: attn_mask_type = kwargs['attn_mask_type']
if is_causal_mask(attn_mask_type):
mask = make_decoder_mask(q_token) mask = make_decoder_mask(q_token)
else: else:
mask = make_attention_mask(q_token > 0, kv_token > 0) mask = make_attention_mask(q_token > 0, kv_token > 0)
...@@ -137,7 +155,8 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs) ...@@ -137,7 +155,8 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs)
""" """
assert q.dtype == kv.dtype assert q.dtype == kv.dtype
if kwargs['attn_mask_type'] == AttnMaskType.CAUSAL_MASK: attn_mask_type = kwargs['attn_mask_type']
if is_causal_mask(attn_mask_type):
raise NotImplementedError raise NotImplementedError
mask = make_attention_mask(q_token > 0, kv_token > 0) mask = make_attention_mask(q_token > 0, kv_token > 0)
...@@ -149,32 +168,28 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs) ...@@ -149,32 +168,28 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs)
@pytest.mark.parametrize('b, s, h, d', SELF_CASES) @pytest.mark.parametrize('b, s, h, d', SELF_CASES)
@pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS]) @pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS])
@pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]) @pytest.mark.parametrize('attn_mask_type', [
AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK,
AttnMaskType.PADDING_CAUSAL_MASK
])
@pytest.mark.parametrize('dropout_probability', [0., 0.1]) @pytest.mark.parametrize('dropout_probability', [0., 0.1])
@pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True, False]) @pytest.mark.parametrize('is_training', [True, False])
@pytest.mark.parametrize('pad_ratio', [0, 0.3])
class TestSelfFusedAttn(): class TestSelfFusedAttn():
"""Tests for transformer_engine.jax.fused_attn.self_fused_attn""" """Tests for transformer_engine.jax.fused_attn.self_fused_attn"""
@staticmethod @staticmethod
def _check_inputs(s, *, attn_bias_type, attn_mask_type, backend, dropout_probability, dtype, def _check_inputs(s, *, attn_bias_type, attn_mask_type, backend, dropout_probability, dtype,
head_dim, pad_ratio): head_dim):
if (s > 512 or backend == Backend.Arbitrary) and pad_ratio != 0:
pytest.skip("Arbitrary seqlen backend hasn't support padded input.") assert isinstance(backend, Backend)
if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type, if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type,
attn_mask_type, dropout_probability, s, s, head_dim): attn_mask_type, dropout_probability, s, s, head_dim):
pytest.skip("Unsupported inputs combination or device compute capability.") pytest.skip("Unsupported inputs combination or device compute capability.")
compute_capability = get_device_compute_capability(0)
if (backend == Backend.Max512
and not (compute_capability == 80 or compute_capability >= 90)):
pytest.skip("Unsupported compute capability for "
"fused attention with <=512 sequence length")
def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend, def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend,
dropout_probability, dtype, is_training, pad_ratio): dropout_probability, dtype, is_training):
"""Setup the test inputs""" """Setup the test inputs"""
self.__class__._check_inputs(s, self.__class__._check_inputs(s,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
...@@ -182,8 +197,13 @@ class TestSelfFusedAttn(): ...@@ -182,8 +197,13 @@ class TestSelfFusedAttn():
backend=backend, backend=backend,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
dtype=dtype, dtype=dtype,
head_dim=d, head_dim=d)
pad_ratio=pad_ratio)
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
pad_ratio = 0.0
else:
pad_ratio = 0.3
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2) subkeys = jax.random.split(key, 2)
...@@ -212,7 +232,7 @@ class TestSelfFusedAttn(): ...@@ -212,7 +232,7 @@ class TestSelfFusedAttn():
self.is_training = is_training self.is_training = is_training
def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend, dropout_probability, def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend, dropout_probability,
dtype, is_training, pad_ratio): dtype, is_training):
""" """
Test forward without using JIT Test forward without using JIT
""" """
...@@ -225,8 +245,7 @@ class TestSelfFusedAttn(): ...@@ -225,8 +245,7 @@ class TestSelfFusedAttn():
backend=backend, backend=backend,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
dtype=dtype, dtype=dtype,
is_training=is_training, is_training=is_training)
pad_ratio=pad_ratio)
primitive_out = customcall_self_fused_attn(self.qkv, primitive_out = customcall_self_fused_attn(self.qkv,
self.bias, self.bias,
...@@ -265,7 +284,7 @@ class TestSelfFusedAttn(): ...@@ -265,7 +284,7 @@ class TestSelfFusedAttn():
jnp.zeros_like(pri_invalid, jnp.float32)) jnp.zeros_like(pri_invalid, jnp.float32))
def test_forward_backward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend, def test_forward_backward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend,
dropout_probability, dtype, is_training, pad_ratio): dropout_probability, dtype, is_training):
""" """
Test forward, backward, and autodiff by jax.value_and_grad Test forward, backward, and autodiff by jax.value_and_grad
""" """
...@@ -281,13 +300,12 @@ class TestSelfFusedAttn(): ...@@ -281,13 +300,12 @@ class TestSelfFusedAttn():
backend=backend, backend=backend,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
dtype=dtype, dtype=dtype,
is_training=is_training, is_training=is_training)
pad_ratio=pad_ratio)
def grad_func(fused_attn_func, *args, **kwargs): def grad_func(fused_attn_func, *args, **kwargs):
# Gradient is small, use a gradient multiplier to amplify the graident # Gradient is small, use a gradient multiplier to amplify the graident
gradient_multiplier = 1000 if dtype == jnp.bfloat16 else 10000 gradient_multiplier = 1000 if dtype == jnp.bfloat16 else 10000
if attn_mask_type == AttnMaskType.CAUSAL_MASK: if is_causal_mask(attn_mask_type):
gradient_multiplier = gradient_multiplier / 10 gradient_multiplier = gradient_multiplier / 10
# Keep only valid result for the gradient # Keep only valid result for the gradient
# fused_attn output has shape (b, s, h, d) # fused_attn output has shape (b, s, h, d)
...@@ -333,15 +351,15 @@ class TestSelfFusedAttn(): ...@@ -333,15 +351,15 @@ class TestSelfFusedAttn():
rtol=1e-4, rtol=1e-4,
atol=1e-5) atol=1e-5)
valid_primitive_dqkv, invalid_primitive_dqkv = jnp.split(primitive_dqkv, (self.valid_len,), valid_primitive_dqkv, invalid_primitive_dqkv = \
axis=1) jnp.split(primitive_dqkv.astype(jnp.float32), (self.valid_len,), axis=1)
valid_reference_dqkv, invalid_reference_dqkv = jnp.split(reference_dqkv, (self.valid_len,), valid_reference_dqkv, invalid_reference_dqkv = \
axis=1) jnp.split(reference_dqkv.astype(jnp.float32), (self.valid_len,), axis=1)
valid_primitive_dq, valid_primitive_dk, valid_primitive_dv = jnp.split( valid_primitive_dq, valid_primitive_dk, valid_primitive_dv = \
valid_primitive_dqkv.astype(jnp.float32), 3, axis=2) jnp.split(valid_primitive_dqkv, 3, axis=2)
valid_reference_dq, valid_reference_dk, valid_reference_dv = jnp.split( valid_reference_dq, valid_reference_dk, valid_reference_dv = \
valid_reference_dqkv.astype(jnp.float32), 3, axis=2) jnp.split(valid_reference_dqkv, 3, axis=2)
np.testing.assert_allclose(valid_primitive_dq, valid_reference_dq, rtol=1e-4, atol=1e-5) np.testing.assert_allclose(valid_primitive_dq, valid_reference_dq, rtol=1e-4, atol=1e-5)
np.testing.assert_allclose(valid_primitive_dk, valid_reference_dk, rtol=1e-4, atol=1e-5) np.testing.assert_allclose(valid_primitive_dk, valid_reference_dk, rtol=1e-4, atol=1e-5)
...@@ -482,9 +500,7 @@ class TestCrossFusedAttn(): ...@@ -482,9 +500,7 @@ class TestCrossFusedAttn():
def grad_func(fused_attn_func, *args, **kwargs): def grad_func(fused_attn_func, *args, **kwargs):
# Gradient is small, use a gradient multiplier to amplify the graident # Gradient is small, use a gradient multiplier to amplify the graident
gradient_multiplier = 10000 gradient_multiplier = 1e4
if attn_mask_type == AttnMaskType.CAUSAL_MASK:
gradient_multiplier = gradient_multiplier / 10
# Keep only valid result for the gradient # Keep only valid result for the gradient
# fused_attn output has shape (b, s_q, h, d) # fused_attn output has shape (b, s_q, h, d)
valid_fused_attn_ret, _ = jnp.split(fused_attn_func(*args, **kwargs), valid_fused_attn_ret, _ = jnp.split(fused_attn_func(*args, **kwargs),
......
...@@ -19,6 +19,16 @@ from utils import EncoderLayer as RefEncoderLayer ...@@ -19,6 +19,16 @@ from utils import EncoderLayer as RefEncoderLayer
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_fp8_available()
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
"""
yield
for arr in jax.live_arrays():
arr.delete()
def loss_fn(diff_xs, no_diff_xs, params, others, model, rngs): def loss_fn(diff_xs, no_diff_xs, params, others, model, rngs):
output = model.apply({"params": params, **others}, *diff_xs, *no_diff_xs, rngs=rngs) output = model.apply({"params": params, **others}, *diff_xs, *no_diff_xs, rngs=rngs)
return jnp.mean(output) return jnp.mean(output)
......
...@@ -38,6 +38,16 @@ ENABLE_FP8 = [False, True] ...@@ -38,6 +38,16 @@ ENABLE_FP8 = [False, True]
FP8_FORMATS = [Format.E4M3, Format.HYBRID] FP8_FORMATS = [Format.E4M3, Format.HYBRID]
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
"""
yield
for arr in jax.live_arrays():
arr.delete()
def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08): def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
for key in ref_fd: for key in ref_fd:
assert key in test_fd, \ assert key in test_fd, \
......
...@@ -87,6 +87,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -87,6 +87,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
const int sm_arch_ = cuda::sm_arch(device_id); const int sm_arch_ = cuda::sm_arch(device_id);
NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type.");
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
auto cudnn_runtime_version = cudnnGetVersion();
if ((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2) if ((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2)
&& (sm_arch_ >= 90) && (sm_arch_ >= 90)
&& (max_seqlen_q == max_seqlen_kv) && (max_seqlen_q == max_seqlen_kv)
...@@ -111,6 +112,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -111,6 +112,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
&& ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) && ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
|| (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS))
&& ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) && ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))
&& ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) && ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
...@@ -131,7 +133,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -131,7 +133,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
&& (max_seqlen_q == max_seqlen_kv) && (max_seqlen_q == max_seqlen_kv)
&& ((head_dim == 64) || (head_dim == 128)) && ((head_dim == 64) || (head_dim == 128))
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
&& (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) && ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
|| ((cudnn_runtime_version >= 8906) &&
(attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)))
&& ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) && ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) || (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD))) { || (qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD))) {
......
...@@ -79,7 +79,7 @@ createScale(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, ...@@ -79,7 +79,7 @@ createScale(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
static cudnn_frontend::Tensor static cudnn_frontend::Tensor
createQKBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, createQKBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout, cudnnDataType_t tensorType, bool padding_aware, NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops) { std::vector<cudnn_frontend::Operation>* ops) {
// Creates the necessary tensor descriptors // Creates the necessary tensor descriptors
int64_t q_dim[4] = {b, h, s_q, d}; int64_t q_dim[4] = {b, h, s_q, d};
...@@ -95,6 +95,9 @@ createQKBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, ...@@ -95,6 +95,9 @@ createQKBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
int64_t s_stride[4]; int64_t s_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, s_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); generateMatrixStrides(b, h, s_q, s_kv, d, s_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix);
int64_t seqlen_dim[4] = {b, 1, 1, 1};
int64_t seqlen_stride[4] = {1, 1, 1, 1};
auto qTensor = tensor_create(tensorType, Q_ID, q_dim, q_stride, false, false); auto qTensor = tensor_create(tensorType, Q_ID, q_dim, q_stride, false, false);
auto kTransposeTensor = tensor_create( auto kTransposeTensor = tensor_create(
tensorType, K_ID, k_dim, k_stride, false, false); // is virtual tensorType, K_ID, k_dim, k_stride, false, false); // is virtual
...@@ -105,21 +108,150 @@ createQKBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, ...@@ -105,21 +108,150 @@ createQKBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
// Define the matmul 1 desc // Define the matmul 1 desc
auto matmul_1_Desc = cudnn_frontend::MatMulDescBuilder() auto matmul_1_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT) .setComputeType(CUDNN_DATA_FLOAT)
.setPaddingValue(0.0f)
.build(); .build();
auto seqlenQTensor = tensor_create(
CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
auto seqlenKTensor = tensor_create(
CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
// Create a matmul 1 node // Create a matmul 1 node
auto matmul_op1 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) auto&& matmul_op_builder =
.setaMatDesc(qTensor) cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR);
.setbMatDesc(kTransposeTensor)
.setcMatDesc(sTensor) matmul_op_builder.setaMatDesc(qTensor)
.setmatmulDesc(matmul_1_Desc) .setbMatDesc(kTransposeTensor)
.build(); .setcMatDesc(sTensor)
.setmatmulDesc(matmul_1_Desc);
if (padding_aware) {
matmul_op_builder.setmOverrideDesc(seqlenQTensor).setnOverrideDesc(seqlenKTensor);
}
auto matmul_op1 = matmul_op_builder.build();
ops->push_back(std::move(matmul_op1)); ops->push_back(std::move(matmul_op1));
return sTensor; return sTensor;
} }
static cudnn_frontend::Tensor
createPaddingMask(int64_t b,
int64_t h,
int64_t s_q,
int64_t s_kv,
int64_t d,
NVTE_QKV_Layout layout,
cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& prevBlockOutputTensor) {
CUDNN_FRONTEND_UNUSED(d);
CUDNN_FRONTEND_UNUSED(layout);
CUDNN_FRONTEND_UNUSED(tensorType);
NVTE_CHECK(ops->size() != 0, "Padding Mask constructed incorrectly as the first one");
// subtraction output
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t maskVal_dim[4] = {1, 1, 1, 1};
int64_t maskVal_stride[4] = {1, 1, 1, 1};
int64_t seqlen_dim[4] = {b, 1, 1, 1};
int64_t seqlen_stride[4] = {1, 1, 1, 1};
// mask value to put in the masked pixels
auto maskValTensor = tensor_create(
CUDNN_DATA_FLOAT, MASK_VAL_ID, maskVal_dim, maskVal_stride, false, true);
auto seqlenQTensor = tensor_create(
CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
auto seqlenKTensor = tensor_create(
CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
// gen index row output
auto rowIndexTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 300, afterBMM1_dim, afterBMM1_stride, true, false);
// gen index column output
auto columnIndexTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 301, afterBMM1_dim, afterBMM1_stride, true, false);
// less than row output
auto lessThanRowTensor = tensor_create(
CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 302, afterBMM1_dim, afterBMM1_stride, true, false);
// less than column output
auto lessThanColTensor = tensor_create(
CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 303, afterBMM1_dim, afterBMM1_stride, true, false);
// padding mask (lessthanRow && lessthanCol)
auto paddingMaskTensor = tensor_create(
CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 304, afterBMM1_dim, afterBMM1_stride, true, false);
// output after masking
auto maskOutputTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 305, afterBMM1_dim, afterBMM1_stride, true, false);
// Define the gen index for row descriptor
auto genIndexRowDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setAxis(2)
.setComputeType(CUDNN_DATA_FLOAT)
.build();
// Create a gen index Node.
auto genIndexRow_op = unary_pw_op_create(
prevBlockOutputTensor, rowIndexTensor, genIndexRowDesc);
// Define the gen index for row descriptor
auto genIndexColumnDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setAxis(3)
.setComputeType(CUDNN_DATA_FLOAT)
.build();
// Create a gen index Node.
auto genIndexColumn_op = unary_pw_op_create(
prevBlockOutputTensor, columnIndexTensor, genIndexColumnDesc);
// Define the less than comparison for row descriptor
auto lessThanRowDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_LT);
// Create a less than comparison for row Node.
auto lessThanRow_op = binary_pw_op_create(
rowIndexTensor, seqlenQTensor, lessThanRowTensor, lessThanRowDesc);
// Define the less than comparison for column descriptor
auto lessThanColDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_LT);
// Create a less than comparison for col Node.
auto lessThanCol_op = binary_pw_op_create(
columnIndexTensor, seqlenKTensor, lessThanColTensor, lessThanColDesc);
// Define the less than comparison for column descriptor
auto paddingMaskAndDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_LOGICAL_AND);
// Create a and node for combining lessThanRow and lessThanCol
auto paddingMaskAnd_op = binary_pw_op_create(
lessThanRowTensor, lessThanColTensor, paddingMaskTensor, paddingMaskAndDesc);
/////////////////// Apply the mask //////////////////////////
// Define the binary select to perform masking descriptor
auto maskDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_BINARY_SELECT);
// Create a binary select Node.
auto mask_op = ternary_pw_op_create(
prevBlockOutputTensor, maskValTensor, paddingMaskTensor, maskOutputTensor, maskDesc);
ops->push_back(std::move(genIndexRow_op));
ops->push_back(std::move(genIndexColumn_op));
ops->push_back(std::move(lessThanRow_op));
ops->push_back(std::move(lessThanCol_op));
ops->push_back(std::move(paddingMaskAnd_op));
ops->push_back(std::move(mask_op));
return maskOutputTensor;
}
static cudnn_frontend::Tensor static cudnn_frontend::Tensor
createCausalMask(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, createCausalMask(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout, cudnnDataType_t tensorType, NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
...@@ -502,7 +634,7 @@ createDropoutBackward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d ...@@ -502,7 +634,7 @@ createDropoutBackward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d
static void static void
createSVBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, createSVBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout, cudnnDataType_t tensorType, bool padding_aware, NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops, std::vector<cudnn_frontend::Operation>* ops,
cudnn_frontend::Tensor const &afterScaleDropoutTensor) { cudnn_frontend::Tensor const &afterScaleDropoutTensor) {
NVTE_CHECK(ops->size() != 0, "BMM2 op constructed incorrectly as the first one"); NVTE_CHECK(ops->size() != 0, "BMM2 op constructed incorrectly as the first one");
...@@ -515,6 +647,14 @@ createSVBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, ...@@ -515,6 +647,14 @@ createSVBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
int64_t o_stride[4]; int64_t o_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix); generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
int64_t seqlen_dim[4] = {b, 1, 1, 1};
int64_t seqlen_stride[4] = {1, 1, 1, 1};
auto seqlenQTensor = tensor_create(
CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
auto seqlenKTensor = tensor_create(
CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
auto vTensor = tensor_create(tensorType, V_ID, v_dim, v_stride, false, false); auto vTensor = tensor_create(tensorType, V_ID, v_dim, v_stride, false, false);
// second GEMM output // second GEMM output
auto oTensor = tensor_create(tensorType, O_ID, o_dim, o_stride, false, false); auto oTensor = tensor_create(tensorType, O_ID, o_dim, o_stride, false, false);
...@@ -522,15 +662,23 @@ createSVBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, ...@@ -522,15 +662,23 @@ createSVBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
// Define the matmul 2 desc // Define the matmul 2 desc
auto matmul_2_Desc = cudnn_frontend::MatMulDescBuilder() auto matmul_2_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT) .setComputeType(CUDNN_DATA_FLOAT)
.setPaddingValue(0.0f)
.build(); .build();
// Create a matmul 2 node // Create a matmul 2 node
auto matmul_op2 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) auto&& matmul_op_builder =
.setaMatDesc(afterScaleDropoutTensor) cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR);
.setbMatDesc(vTensor)
.setcMatDesc(oTensor) matmul_op_builder.setaMatDesc(afterScaleDropoutTensor)
.setmatmulDesc(matmul_2_Desc) .setbMatDesc(vTensor)
.build(); .setcMatDesc(oTensor)
.setmatmulDesc(matmul_2_Desc);
if (padding_aware) {
matmul_op_builder.setmOverrideDesc(seqlenQTensor).setkOverrideDesc(seqlenKTensor);
}
auto matmul_op2 = matmul_op_builder.build();
ops->push_back(std::move(matmul_op2)); ops->push_back(std::move(matmul_op2));
} }
...@@ -538,9 +686,10 @@ createSVBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, ...@@ -538,9 +686,10 @@ createSVBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
void fused_attn_arbitrary_seqlen_fwd_impl( void fused_attn_arbitrary_seqlen_fwd_impl(
int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
bool is_training, float scaling_factor, float dropout_probability, bool is_training, float scaling_factor, float dropout_probability,
NVTE_QKV_Layout layout, NVTE_QKV_Layout layout, NVTE_Mask_Type mask_type,
void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrQ, void *devPtrK, void *devPtrV,
void *devPtrSoftmaxStats, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrO,
void *devPtrCuSeqlenQ, void *devPtrCuSeqlenKV,
void* devPtrDropoutSeed, void* devPtrDropoutOffset, void* devPtrDropoutSeed, void* devPtrDropoutOffset,
cudnnDataType_t tensorType, cudnnDataType_t tensorType,
void *workspace, size_t *workspace_size, void *workspace, size_t *workspace_size,
...@@ -552,12 +701,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -552,12 +701,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
dropout_probability = 0.0f; dropout_probability = 0.0f;
} }
// also known as variable_sequence_length
bool padding_aware = (mask_type == NVTE_PADDING_MASK) ||
(mask_type == NVTE_PADDING_CAUSAL_MASK);
FADescriptor descriptor{b, h, FADescriptor descriptor{b, h,
s_q, s_kv, s_q, s_kv,
d, scaling_factor, d, scaling_factor,
is_training, dropout_probability, is_training, dropout_probability,
layout, NVTE_Bias_Type::NVTE_NO_BIAS, layout, NVTE_Bias_Type::NVTE_NO_BIAS,
NVTE_Mask_Type::NVTE_CAUSAL_MASK, tensorType, mask_type, tensorType,
false}; false};
using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>; using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
...@@ -577,15 +730,24 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -577,15 +730,24 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std::vector<cudnn_frontend::Operation> ops; std::vector<cudnn_frontend::Operation> ops;
// Q * K^T // Q * K^T
auto sTensor = createQKBMM(b, h, s_q, s_kv, d, layout, tensorType, &ops); auto sTensor = createQKBMM(
b, h, s_q, s_kv, d, padding_aware, layout, tensorType, &ops);
// Q * K^T * bmmScale // Q * K^T * bmmScale
auto sScaleTensor = createScale( auto sScaleTensor = createScale(
b, h, s_q, s_kv, d, layout, CUDNN_DATA_FLOAT, sTensor, &ops); b, h, s_q, s_kv, d, layout, CUDNN_DATA_FLOAT, sTensor, &ops);
// Causual mask auto& sAfterMaskTensor = sScaleTensor;
auto sAfterMaskTensor = createCausalMask(
b, h, s_q, s_kv, d, layout, tensorType, &ops, sScaleTensor); if (mask_type == NVTE_CAUSAL_MASK || mask_type == NVTE_PADDING_CAUSAL_MASK) {
sAfterMaskTensor = createCausalMask(
b, h, s_q, s_kv, d, layout, tensorType, &ops, sScaleTensor);
}
if (padding_aware) {
sAfterMaskTensor = createPaddingMask(
b, h, s_q, s_kv, d, layout, tensorType, &ops, sAfterMaskTensor);
}
NVTE_CHECK(dropout_probability != 1.0f, NVTE_CHECK(dropout_probability != 1.0f,
"Dropout probability cannot be 1.0"); "Dropout probability cannot be 1.0");
...@@ -597,7 +759,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -597,7 +759,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
auto dropout_output = createDropoutForward( auto dropout_output = createDropoutForward(
b, h, s_q, s_kv, d, b, h, s_q, s_kv, d,
dropout_probability, tensorType, &ops, softmax_output); dropout_probability, tensorType, &ops, softmax_output);
createSVBMM(b, h, s_q, s_kv, d, layout, tensorType, &ops, dropout_output); createSVBMM(b, h, s_q, s_kv, d, padding_aware,
layout, tensorType, &ops, dropout_output);
for (unsigned int i = 0; i < ops.size(); i++) { for (unsigned int i = 0; i < ops.size(); i++) {
all_ops.push_back(&ops[i]); all_ops.push_back(&ops[i]);
...@@ -636,13 +799,29 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -636,13 +799,29 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
// Exit to request upper level API to allocate memory if needed // Exit to request upper level API to allocate memory if needed
if (workspace == nullptr) { if (workspace == nullptr) {
*workspace_size = plan_workspace_size; size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t);
*workspace_size = plan_workspace_size + actual_seqlen_workspace_size;
return; return;
} }
// Prepare actual seqlen
constexpr size_t nthreads_per_block = 128;
const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
void *devActualSeqlenK = static_cast<int8_t *>(devActualSeqlenQ) + b * sizeof(int32_t);
if (padding_aware) {
cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
b, static_cast<const int32_t *>(devPtrCuSeqlenQ),
static_cast<const int32_t *>(devPtrCuSeqlenKV),
static_cast<int32_t *>(devActualSeqlenQ),
static_cast<int32_t *>(devActualSeqlenK));
NVTE_CHECK_CUDA(cudaGetLastError());
}
std::set<std::pair<uint64_t, void*>> data_ptrs; std::set<std::pair<uint64_t, void*>> data_ptrs;
// Add all the data pointers to be used in the variant pack // Add all the data pointers to be used in the variant pack
float negInfinity = -1.0E+10f; float negInfinity = -1.0E+30f;
float scale_dropout = 1.0f/(1.0f - dropout_probability); float scale_dropout = 1.0f/(1.0f - dropout_probability);
data_ptrs.insert(std::pair<uint64_t, void*>(Q_ID, devPtrQ)); data_ptrs.insert(std::pair<uint64_t, void*>(Q_ID, devPtrQ));
...@@ -655,6 +834,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -655,6 +834,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
data_ptrs.insert(std::pair<uint64_t, void*>(D_OFFSET_ID, devPtrDropoutOffset)); data_ptrs.insert(std::pair<uint64_t, void*>(D_OFFSET_ID, devPtrDropoutOffset));
data_ptrs.insert(std::pair<uint64_t, void*>(D_CONST_ID, &scale_dropout)); data_ptrs.insert(std::pair<uint64_t, void*>(D_CONST_ID, &scale_dropout));
if (padding_aware) {
data_ptrs.insert(std::pair<uint64_t, void*>(Q_SEQLEN_ID, devActualSeqlenQ));
data_ptrs.insert(std::pair<uint64_t, void*>(K_SEQLEN_ID, devActualSeqlenK));
}
// If training mode, we write out softmax stats // If training mode, we write out softmax stats
if (is_training) { if (is_training) {
data_ptrs.insert(std::pair<uint64_t, void*>(S_STATS_ID, devPtrSoftmaxStats)); data_ptrs.insert(std::pair<uint64_t, void*>(S_STATS_ID, devPtrSoftmaxStats));
...@@ -675,21 +859,26 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -675,21 +859,26 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
void fused_attn_arbitrary_seqlen_bwd_impl( void fused_attn_arbitrary_seqlen_bwd_impl(
int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
void* devPtrQ, void* devPtrKTranspose, void* devPtrVTranspose, NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrKTranspose,
void* devPtrO, void* devPtrSoftmaxStats, void* devPtrVTranspose, void* devPtrO, void* devPtrSoftmaxStats,
void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrdO, void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrdO,
void *devPtrCuSeqlenQ, void *devPtrCuSeqlenKV,
void* devPtrDropoutSeed, void* devPtrDropoutOffset, void* devPtrDropoutSeed, void* devPtrDropoutOffset,
cudnnDataType_t tensorType, void *workspace, size_t *workspace_size, cudnnDataType_t tensorType, void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle, bool use_workspace_opt) { cudaStream_t stream, cudnnHandle_t handle, bool use_workspace_opt) {
try { try {
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
// also known as variable_sequence_length
bool padding_aware = (mask_type == NVTE_PADDING_MASK) ||
(mask_type == NVTE_PADDING_CAUSAL_MASK);
FADescriptor descriptor{b, h, FADescriptor descriptor{b, h,
s_q, s_kv, s_q, s_kv,
d, scaling_factor, d, scaling_factor,
true, dropout_probability, true, dropout_probability,
layout, NVTE_Bias_Type::NVTE_NO_BIAS, layout, NVTE_Bias_Type::NVTE_NO_BIAS,
NVTE_Mask_Type::NVTE_CAUSAL_MASK, tensorType, mask_type, tensorType,
use_workspace_opt}; use_workspace_opt};
using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>; using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
...@@ -747,9 +936,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -747,9 +936,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
generateMatrixStrides(b, h, s_q, s_kv, d, dqAccum_stride, generateMatrixStrides(b, h, s_q, s_kv, d, dqAccum_stride,
layout, NVTE_QKV_Matrix::NVTE_O_Matrix); layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
int64_t seqlen_dim[4] = {b, 1, 1, 1};
int64_t seqlen_stride[4] = {1, 1, 1, 1};
int64_t scale_dim[4] = {1, 1, 1, 1}; int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1}; int64_t scale_stride[4] = {1, 1, 1, 1};
auto seqlenQTensor = tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim,
seqlen_stride, false, false);
auto seqlenKTensor = tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim,
seqlen_stride, false, false);
/******************************************************************************* /*******************************************************************************
* Dot product dO * O */ * Dot product dO * O */
...@@ -823,15 +1020,22 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -823,15 +1020,22 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
// matmul to calculate dvTensor // matmul to calculate dvTensor
auto matmul_0_Desc = cudnn_frontend::MatMulDescBuilder() auto matmul_0_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT) .setComputeType(CUDNN_DATA_FLOAT)
.setPaddingValue(0.0f)
.build(); .build();
auto matmul_op0 = cudnn_frontend::OperationBuilder( auto&& matmul_op_builder =
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR);
.setaMatDesc(qTensor)
.setbMatDesc(kTransposeTensor) matmul_op_builder.setaMatDesc(qTensor)
.setcMatDesc(pTensor) .setbMatDesc(kTransposeTensor)
.setmatmulDesc(matmul_0_Desc) .setcMatDesc(pTensor)
.build(); .setmatmulDesc(matmul_0_Desc);
if (padding_aware) {
matmul_op_builder.setmOverrideDesc(seqlenQTensor).setnOverrideDesc(seqlenKTensor);
}
auto matmul_op0 = matmul_op_builder.build();
ops.push_back(std::move(matmul_op0)); ops.push_back(std::move(matmul_op0));
...@@ -851,8 +1055,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -851,8 +1055,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
/******************************************************************************* /*******************************************************************************
* Causal masking -> pAfterMaskTensor */ * Causal masking -> pAfterMaskTensor */
auto pAfterMaskTensor = createCausalMask( auto& pAfterMaskTensor = pAfterScaleTensor;
b, h, s_q, s_kv, d, layout, tensorType, &ops, pAfterScaleTensor);
if (mask_type == NVTE_CAUSAL_MASK || mask_type == NVTE_PADDING_CAUSAL_MASK) {
pAfterMaskTensor = createCausalMask(
b, h, s_q, s_kv, d, layout, tensorType, &ops, pAfterScaleTensor);
}
if (padding_aware) {
pAfterMaskTensor = createPaddingMask(
b, h, s_q, s_kv, d, layout, tensorType, &ops, pAfterMaskTensor);
}
/******************************************************************************* /*******************************************************************************
* pAfterMaskTensor - softmaxStats -> pAfterSubtract */ * pAfterMaskTensor - softmaxStats -> pAfterSubtract */
...@@ -930,15 +1143,22 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -930,15 +1143,22 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
auto matmul_1_Desc = cudnn_frontend::MatMulDescBuilder() auto matmul_1_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT) .setComputeType(CUDNN_DATA_FLOAT)
.setPaddingValue(0.0f)
.build(); .build();
auto matmul_op1 = cudnn_frontend::OperationBuilder( auto&& matmul_op1_builder =
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR);
.setaMatDesc(sTransposeTensor)
.setbMatDesc(dOTensor) matmul_op1_builder.setaMatDesc(sTransposeTensor)
.setcMatDesc(dVTensor) .setbMatDesc(dOTensor)
.setmatmulDesc(matmul_1_Desc) .setcMatDesc(dVTensor)
.build(); .setmatmulDesc(matmul_1_Desc);
if (padding_aware) {
matmul_op1_builder.setmOverrideDesc(seqlenKTensor).setkOverrideDesc(seqlenQTensor);
}
auto matmul_op1 = matmul_op1_builder.build();
ops.push_back(std::move(matmul_op1)); ops.push_back(std::move(matmul_op1));
...@@ -954,15 +1174,22 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -954,15 +1174,22 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
auto matmul_2_Desc = cudnn_frontend::MatMulDescBuilder() auto matmul_2_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT) .setComputeType(CUDNN_DATA_FLOAT)
.setPaddingValue(0.0f)
.build(); .build();
auto matmul_op2 = cudnn_frontend::OperationBuilder( auto&& matmul_op2_builder =
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR);
.setaMatDesc(dOTensor)
.setbMatDesc(vTransposeTensor) matmul_op2_builder.setaMatDesc(dOTensor)
.setcMatDesc(dSTensor) .setbMatDesc(vTransposeTensor)
.setmatmulDesc(matmul_2_Desc) .setcMatDesc(dSTensor)
.build(); .setmatmulDesc(matmul_2_Desc);
if (padding_aware) {
matmul_op2_builder.setmOverrideDesc(seqlenQTensor).setnOverrideDesc(seqlenKTensor);
}
auto matmul_op2 = matmul_op2_builder.build();
ops.push_back(std::move(matmul_op2)); ops.push_back(std::move(matmul_op2));
...@@ -1059,30 +1286,30 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -1059,30 +1286,30 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
auto matmul_3_Desc = cudnn_frontend::MatMulDescBuilder() auto matmul_3_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT) .setComputeType(CUDNN_DATA_FLOAT)
.setPaddingValue(0.0f)
.build(); .build();
if (!use_workspace_opt) { auto&& matmul_op3_builder =
auto matmul_op3 = cudnn_frontend::OperationBuilder( cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR);
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(dPScaledTensor) matmul_op3_builder.setaMatDesc(dPScaledTensor)
.setbMatDesc(kTensor) .setbMatDesc(kTensor)
.setcMatDesc(dqAccumTensor) .setmatmulDesc(matmul_3_Desc);
.setmatmulDesc(matmul_3_Desc)
.build(); if (use_workspace_opt) {
matmul_op3_builder.setcMatDesc(dQTensor);
ops.push_back(std::move(matmul_op3));
} else { } else {
auto matmul_op3 = cudnn_frontend::OperationBuilder( matmul_op3_builder.setcMatDesc(dqAccumTensor);
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) }
.setaMatDesc(dPScaledTensor)
.setbMatDesc(kTensor) if (padding_aware) {
.setcMatDesc(dQTensor) matmul_op3_builder.setmOverrideDesc(seqlenQTensor).setkOverrideDesc(seqlenKTensor);
.setmatmulDesc(matmul_3_Desc)
.build();
ops.push_back(std::move(matmul_op3));
} }
auto matmul_op3 = matmul_op3_builder.build();
ops.push_back(std::move(matmul_op3));
/******************************************************************************* /*******************************************************************************
* dP.T @ Q -> dK */ * dP.T @ Q -> dK */
...@@ -1098,14 +1325,22 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -1098,14 +1325,22 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
auto matmul_4_Desc = cudnn_frontend::MatMulDescBuilder() auto matmul_4_Desc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT) .setComputeType(CUDNN_DATA_FLOAT)
.setPaddingValue(0.0f)
.build(); .build();
auto matmul_op4 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) auto&& matmul_op4_builder =
.setaMatDesc(dPTransposeTensor) cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR);
.setbMatDesc(qTensor)
.setcMatDesc(dKTensor) matmul_op4_builder.setaMatDesc(dPTransposeTensor)
.setmatmulDesc(matmul_4_Desc) .setbMatDesc(qTensor)
.build(); .setcMatDesc(dKTensor)
.setmatmulDesc(matmul_4_Desc);
if (padding_aware) {
matmul_op4_builder.setmOverrideDesc(seqlenKTensor).setkOverrideDesc(seqlenQTensor);
}
auto matmul_op4 = matmul_op4_builder.build();
ops.push_back(std::move(matmul_op4)); ops.push_back(std::move(matmul_op4));
...@@ -1153,29 +1388,36 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -1153,29 +1388,36 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
// Exit to request upper level API to allocate memory if needed // Exit to request upper level API to allocate memory if needed
size_t softmaxSum_workspace_size = b * h * s_q * sizeof(float); size_t softmaxSum_workspace_size = b * h * s_q * sizeof(float);
size_t dqAccum_workspace_size = b * s_q * h * d * sizeof(float); size_t dqAccum_workspace_size = use_workspace_opt ? 0 : b * s_q * h * d * sizeof(float);
size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t);
if (workspace == nullptr) { if (workspace == nullptr) {
if (use_workspace_opt) { *workspace_size = plan_workspace_size + softmaxSum_workspace_size
*workspace_size = plan_workspace_size + softmaxSum_workspace_size; + dqAccum_workspace_size + actual_seqlen_workspace_size;
} else {
*workspace_size = plan_workspace_size + softmaxSum_workspace_size
+ dqAccum_workspace_size;
}
return; return;
} }
void *devPtrSoftmaxSum = static_cast<int8_t *>(workspace) + plan_workspace_size; void *devPtrSoftmaxSum = static_cast<int8_t *>(workspace) + plan_workspace_size;
void *devPtrdQAccumulator = nullptr; void *devPtrdQAccumulator = static_cast<int8_t *>(devPtrSoftmaxSum)
if (!use_workspace_opt) {
devPtrdQAccumulator = static_cast<int8_t *>(devPtrSoftmaxSum)
+ softmaxSum_workspace_size; + softmaxSum_workspace_size;
if (!use_workspace_opt) {
NVTE_CHECK_CUDA(cudaMemsetAsync( NVTE_CHECK_CUDA(cudaMemsetAsync(
devPtrdQAccumulator, 0, dqAccum_workspace_size, stream)); devPtrdQAccumulator, 0, dqAccum_workspace_size, stream));
} }
constexpr size_t nthreads_per_block = 128;
const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
void *devActualSeqlenQ =
static_cast<int8_t *>(devPtrdQAccumulator) + dqAccum_workspace_size;
void *devActualSeqlenK = static_cast<int8_t *>(devActualSeqlenQ) + b * sizeof(int32_t);
cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
b, static_cast<const int32_t *>(devPtrCuSeqlenQ),
static_cast<const int32_t *>(devPtrCuSeqlenKV),
static_cast<int32_t *>(devActualSeqlenQ), static_cast<int32_t *>(devActualSeqlenK));
NVTE_CHECK_CUDA(cudaGetLastError());
std::set<std::pair<uint64_t, void *>> data_ptrs; std::set<std::pair<uint64_t, void *>> data_ptrs;
// add all the data pointers to be used in the variant pack // add all the data pointers to be used in the variant pack
float negInfinity = -1.0E+10f; float negInfinity = -1.0E+31f;
float scale_dropout = 1.0f/(1.0f - dropout_probability); float scale_dropout = 1.0f/(1.0f - dropout_probability);
data_ptrs.insert(std::pair<uint64_t, void*>(dQ_ID, devPtrdQ)); data_ptrs.insert(std::pair<uint64_t, void*>(dQ_ID, devPtrdQ));
if (!use_workspace_opt) { if (!use_workspace_opt) {
...@@ -1194,6 +1436,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -1194,6 +1436,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
data_ptrs.insert(std::pair<uint64_t, void*>(D_SEED_ID, devPtrDropoutSeed)); data_ptrs.insert(std::pair<uint64_t, void*>(D_SEED_ID, devPtrDropoutSeed));
data_ptrs.insert(std::pair<uint64_t, void*>(D_OFFSET_ID, devPtrDropoutOffset)); data_ptrs.insert(std::pair<uint64_t, void*>(D_OFFSET_ID, devPtrDropoutOffset));
data_ptrs.insert(std::pair<uint64_t, void*>(MASK_VAL_ID, &negInfinity)); data_ptrs.insert(std::pair<uint64_t, void*>(MASK_VAL_ID, &negInfinity));
if (padding_aware) {
data_ptrs.insert(std::pair<uint64_t, void *>(Q_SEQLEN_ID, devActualSeqlenQ));
data_ptrs.insert(std::pair<uint64_t, void *>(K_SEQLEN_ID, devActualSeqlenK));
}
float scaleProb = 1.0f - dropout_probability; float scaleProb = 1.0f - dropout_probability;
data_ptrs.insert(std::pair<uint64_t, void*>(D_CONST_ID, &scale_dropout)); data_ptrs.insert(std::pair<uint64_t, void*>(D_CONST_ID, &scale_dropout));
...@@ -1254,6 +1500,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( ...@@ -1254,6 +1500,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
} }
void *devPtrCuSeqlens = cu_seqlens->data.dptr;
void* devPtrDropoutSeed = rng_state->data.dptr; void* devPtrDropoutSeed = rng_state->data.dptr;
void* devPtrDropoutOffset = reinterpret_cast<void *>( void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1); reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
...@@ -1262,8 +1510,9 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( ...@@ -1262,8 +1510,9 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t workspace_size = 0; size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_fwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim, fused_attn_arbitrary_seqlen_fwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim,
is_training, attn_scale, p_dropout, qkv_layout, is_training, attn_scale, p_dropout, qkv_layout, mask_type,
devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO,
devPtrCuSeqlens, devPtrCuSeqlens,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type), get_cudnn_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle); workspace->data.dptr, &workspace_size, stream, handle);
...@@ -1318,6 +1567,8 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen, ...@@ -1318,6 +1567,8 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen,
void* devPtrDropoutOffset = reinterpret_cast<void *>( void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1); reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
void *devPtrCuSeqlens = cu_seqlens->data.dptr;
const auto qkv_type = input_QKV->data.dtype; const auto qkv_type = input_QKV->data.dtype;
size_t workspace_size = 0; size_t workspace_size = 0;
...@@ -1349,9 +1600,10 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen, ...@@ -1349,9 +1600,10 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen,
#endif #endif
fused_attn_arbitrary_seqlen_bwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim, fused_attn_arbitrary_seqlen_bwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim,
attn_scale, p_dropout, qkv_layout, attn_scale, p_dropout, qkv_layout, mask_type,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrdO,
devPtrCuSeqlens, devPtrCuSeqlens,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(qkv_type), workspace->data.dptr, get_cudnn_dtype(qkv_type), workspace->data.dptr,
&workspace_size, stream, handle, use_workspace_opt); &workspace_size, stream, handle, use_workspace_opt);
...@@ -1412,11 +1664,15 @@ void fused_attn_arbitrary_seqlen_fwd( ...@@ -1412,11 +1664,15 @@ void fused_attn_arbitrary_seqlen_fwd(
void* devPtrDropoutOffset = reinterpret_cast<void *>( void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1); reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
size_t workspace_size = 0; size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_fwd_impl(batch, num_head, max_seqlen_q, max_seqlen_kv, head_dim, fused_attn_arbitrary_seqlen_fwd_impl(batch, num_head, max_seqlen_q, max_seqlen_kv, head_dim,
is_training, attn_scale, p_dropout, qkv_layout, is_training, attn_scale, p_dropout, qkv_layout, mask_type,
devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO,
devPtrCuSeqlensQ, devPtrCuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type), get_cudnn_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle); workspace->data.dptr, &workspace_size, stream, handle);
...@@ -1467,6 +1723,9 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t max_seqlen_q, size_t m ...@@ -1467,6 +1723,9 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t max_seqlen_q, size_t m
void* devPtrDropoutOffset = reinterpret_cast<void *>( void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1); reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
size_t workspace_size = 0; size_t workspace_size = 0;
bool use_workspace_opt = false; bool use_workspace_opt = false;
...@@ -1497,9 +1756,10 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t max_seqlen_q, size_t m ...@@ -1497,9 +1756,10 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t max_seqlen_q, size_t m
#endif #endif
fused_attn_arbitrary_seqlen_bwd_impl(batch, num_head, max_seqlen_q, max_seqlen_kv, head_dim, fused_attn_arbitrary_seqlen_bwd_impl(batch, num_head, max_seqlen_q, max_seqlen_kv, head_dim,
attn_scale, p_dropout, qkv_layout, attn_scale, p_dropout, qkv_layout, mask_type,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrdO,
devPtrCuSeqlensQ, devPtrCuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type), workspace->data.dptr, get_cudnn_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle, use_workspace_opt); &workspace_size, stream, handle, use_workspace_opt);
......
...@@ -298,7 +298,8 @@ static cudnn_frontend::Tensor createMask(int64_t b, int64_t h, int64_t s_q, int6 ...@@ -298,7 +298,8 @@ static cudnn_frontend::Tensor createMask(int64_t b, int64_t h, int64_t s_q, int6
/////////////////// Apply the mask ////////////////////////// /////////////////// Apply the mask //////////////////////////
auto maskTensor = (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) auto maskTensor = (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)
? std::move(causalMaskTensor) ? std::move(causalMaskTensor)
: std::move(paddingMaskTensor); : std::move(paddingMaskTensor);
...@@ -314,7 +315,8 @@ static cudnn_frontend::Tensor createMask(int64_t b, int64_t h, int64_t s_q, int6 ...@@ -314,7 +315,8 @@ static cudnn_frontend::Tensor createMask(int64_t b, int64_t h, int64_t s_q, int6
ops.push_back(std::move(lessThanRow_op)); ops.push_back(std::move(lessThanRow_op));
ops.push_back(std::move(lessThanCol_op)); ops.push_back(std::move(lessThanCol_op));
ops.push_back(std::move(paddingMaskAnd_op)); ops.push_back(std::move(paddingMaskAnd_op));
if (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) { if (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) {
ops.push_back(std::move(rowGreaterCol_op)); ops.push_back(std::move(rowGreaterCol_op));
ops.push_back(std::move(causalMaskAnd_op)); ops.push_back(std::move(causalMaskAnd_op));
} }
...@@ -680,7 +682,8 @@ void fused_attn_max_512_fwd_impl( ...@@ -680,7 +682,8 @@ void fused_attn_max_512_fwd_impl(
// WAR: causal_mask without bias needs memset the S buffer // WAR: causal_mask without bias needs memset the S buffer
// inference mode doesn't need the S auxiliary // inference mode doesn't need the S auxiliary
auto zero_s = (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) || auto zero_s = (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) ||
(mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) && is_training; (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) && is_training;
std::shared_ptr<cudnn_frontend::Tensor> maskInput; std::shared_ptr<cudnn_frontend::Tensor> maskInput;
auto bmm1_output = createBMM1(b, h, s_q, s_kv, d, layout, tensorType, zero_s, ops); auto bmm1_output = createBMM1(b, h, s_q, s_kv, d, layout, tensorType, zero_s, ops);
......
...@@ -146,6 +146,8 @@ enum NVTE_Mask_Type { ...@@ -146,6 +146,8 @@ enum NVTE_Mask_Type {
NVTE_PADDING_MASK = 1, NVTE_PADDING_MASK = 1,
/*! Causal attention mask */ /*! Causal attention mask */
NVTE_CAUSAL_MASK = 2, NVTE_CAUSAL_MASK = 2,
/*! Padding and causal attention mask */
NVTE_PADDING_CAUSAL_MASK = 3,
}; };
/*! \enum NVTE_Fused_Attn_Backend /*! \enum NVTE_Fused_Attn_Backend
...@@ -209,10 +211,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -209,10 +211,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* *
* Support Matrix: * Support Matrix:
\verbatim \verbatim
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK | Yes | <= 512 | 64 | | 0 | FP16/BF16 | QKV_INTERLEAVED | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | CAUSAL_MASK | Yes | > 512 | 64, 128 | | 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | PADDING/CAUSAL/PADDING_CAUSAL | Yes | > 512 | 64, 128 |
| 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 | | 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
\endverbatim \endverbatim
* *
* \param[in] QKV The QKV tensor in packed format, * \param[in] QKV The QKV tensor in packed format,
...@@ -254,10 +256,10 @@ void nvte_fused_attn_fwd_qkvpacked( ...@@ -254,10 +256,10 @@ void nvte_fused_attn_fwd_qkvpacked(
* *
* Support Matrix: * Support Matrix:
\verbatim \verbatim
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK | Yes | <= 512 | 64 | | 0 | FP16/BF16 | QKV_INTERLEAVED | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | CAUSAL_MASK | Yes | > 512 | 64, 128 | | 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | PADDING/CAUSAL/PADDING_CAUSAL | Yes | > 512 | 64, 128 |
| 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 | | 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
\endverbatim \endverbatim
* *
* \param[in] QKV The QKV tensor in packed format, * \param[in] QKV The QKV tensor in packed format,
...@@ -308,8 +310,8 @@ void nvte_fused_attn_bwd_qkvpacked( ...@@ -308,8 +310,8 @@ void nvte_fused_attn_bwd_qkvpacked(
* *
* Support Matrix: * Support Matrix:
\verbatim \verbatim
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | KV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK | Yes | <= 512 | 64 | | 0 | FP16/BF16 | KV_INTERLEAVED | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 |
\endverbatim \endverbatim
* *
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim]. * \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
...@@ -356,8 +358,8 @@ void nvte_fused_attn_fwd_kvpacked( ...@@ -356,8 +358,8 @@ void nvte_fused_attn_fwd_kvpacked(
* *
* Support Matrix: * Support Matrix:
\verbatim \verbatim
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | KV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK | Yes | <= 512 | 64 | | 0 | FP16/BF16 | KV_INTERLEAVED | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 |
\endverbatim \endverbatim
* *
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim]. * \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
...@@ -415,10 +417,10 @@ void nvte_fused_attn_bwd_kvpacked( ...@@ -415,10 +417,10 @@ void nvte_fused_attn_bwd_kvpacked(
* *
* Support Matrix: * Support Matrix:
\verbatim \verbatim
| backend | precision | qkv format | bias | mask | dropout | sequence length | head_dim | | backend | precision | qkv format | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | PADDING/CAUSAL_MASK | Yes | <= 512 | 64 | | 0 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | CAUSAL_MASK | Yes | > 512 | 64, 128 | | 1 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | PADDING/CAUSAL/PADDING_CAUSAL | Yes | > 512 | 64, 128 |
| 2 | FP8 | THD | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 | | 2 | FP8 | THD | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
\endverbatim \endverbatim
* *
* \param[in] Q The Q tensor. * \param[in] Q The Q tensor.
...@@ -467,10 +469,10 @@ void nvte_fused_attn_fwd( ...@@ -467,10 +469,10 @@ void nvte_fused_attn_fwd(
* *
* Support Matrix: * Support Matrix:
\verbatim \verbatim
| backend | precision | qkv format | bias | mask | dropout | sequence length | head_dim | | backend | precision | qkv format | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | PADDING/CAUSAL_MASK | Yes | <= 512 | 64 | | 0 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | CAUSAL_MASK | Yes | > 512 | 64, 128 | | 1 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | PADDING/CAUSAL/PADDING_CAUSAL | Yes | > 512 | 64, 128 |
| 2 | FP8 | THD | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 | | 2 | FP8 | THD | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
\endverbatim \endverbatim
* *
* \param[in] Q The Q tensor. * \param[in] Q The Q tensor.
......
...@@ -83,7 +83,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) { ...@@ -83,7 +83,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
pybind11::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type", pybind11::module_local()) pybind11::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type", pybind11::module_local())
.value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
.value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
.value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK); .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK)
.value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK);
pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local()) pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local())
.value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD)
......
...@@ -102,7 +102,7 @@ def _combine_biases(*masks: List[Array]): ...@@ -102,7 +102,7 @@ def _combine_biases(*masks: List[Array]):
return mask return mask
class Softmax(nn.Module): class Softmax(nn.Module): # pylint: disable=too-few-public-methods
r""" r"""
Applies softmax over a mini-batch of inputs. Applies softmax over a mini-batch of inputs.
The input's shape should be [batch, heads, q_seqlen, k_seqlen]. The input's shape should be [batch, heads, q_seqlen, k_seqlen].
...@@ -176,7 +176,7 @@ class Softmax(nn.Module): ...@@ -176,7 +176,7 @@ class Softmax(nn.Module):
return outputs return outputs
class LayerNorm(nn.Module): class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
r""" r"""
Applies layer normalization over a mini-batch of inputs. Applies layer normalization over a mini-batch of inputs.
There are two types of normalization supported by this module, There are two types of normalization supported by this module,
...@@ -431,8 +431,9 @@ class DenseGeneral(TransformerEngineBase): ...@@ -431,8 +431,9 @@ class DenseGeneral(TransformerEngineBase):
bias = nn_partitioning.param_with_axes('bias', bias = nn_partitioning.param_with_axes('bias',
self.bias_init, self.bias_init,
features, features,
self.dtype, jnp.float32,
axes=self.bias_axes) axes=self.bias_axes)
bias = bias.astype(self.dtype)
else: else:
bias = None bias = None
...@@ -656,8 +657,9 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -656,8 +657,9 @@ class LayerNormDenseGeneral(TransformerEngineBase):
bias = nn_partitioning.param_with_axes('bias', bias = nn_partitioning.param_with_axes('bias',
self.bias_init, self.bias_init,
features, features,
self.dtype, jnp.float32,
axes=self.bias_axes) axes=self.bias_axes)
bias = bias.astype(self.dtype)
if bias is not None: if bias is not None:
bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
...@@ -969,8 +971,9 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -969,8 +971,9 @@ class LayerNormMLP(TransformerEngineBase):
bias = nn_partitioning.param_with_axes('wi_bias', bias = nn_partitioning.param_with_axes('wi_bias',
self.bias_init, self.bias_init,
intermediate_dim, intermediate_dim,
self.dtype, jnp.float32,
axes=self.bias_axes_1) axes=self.bias_axes_1)
bias = bias.astype(self.dtype)
bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape
x += jnp.reshape(bias, bias_shape) x += jnp.reshape(bias, bias_shape)
...@@ -1029,8 +1032,9 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1029,8 +1032,9 @@ class LayerNormMLP(TransformerEngineBase):
if self.use_bias: if self.use_bias:
bias = nn_partitioning.param_with_axes('wo_bias', bias = nn_partitioning.param_with_axes('wo_bias',
self.bias_init, (hidden_size,), self.bias_init, (hidden_size,),
self.dtype, jnp.float32,
axes=self.bias_axes_2) axes=self.bias_axes_2)
bias = bias.astype(self.dtype)
out += jnp.reshape(bias, (1,) * (out.ndim - 1) + (-1,)) out += jnp.reshape(bias, (1,) * (out.ndim - 1) + (-1,))
return out, ln_output # Output, layner_norm_output return out, ln_output # Output, layner_norm_output
...@@ -247,7 +247,7 @@ def core_attention(query: Array, ...@@ -247,7 +247,7 @@ def core_attention(query: Array,
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
class MultiHeadAttention(nn.Module): class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
r""" r"""
Multi-head Attention (MHA), including Query, Multi-head Attention (MHA), including Query,
Key, Value and Output projection. Key, Value and Output projection.
...@@ -422,7 +422,7 @@ class MultiHeadAttention(nn.Module): ...@@ -422,7 +422,7 @@ class MultiHeadAttention(nn.Module):
Convert the string to AttnMaskType Convert the string to AttnMaskType
""" """
if attn_mask_type == 'causal': if attn_mask_type == 'causal':
return AttnMaskType.CAUSAL_MASK return AttnMaskType.PADDING_CAUSAL_MASK
if attn_mask_type == 'padding': if attn_mask_type == 'padding':
return AttnMaskType.PADDING_MASK return AttnMaskType.PADDING_MASK
raise ValueError(f"Unsupported {attn_mask_type=}, " raise ValueError(f"Unsupported {attn_mask_type=}, "
...@@ -741,7 +741,7 @@ class MultiHeadAttention(nn.Module): ...@@ -741,7 +741,7 @@ class MultiHeadAttention(nn.Module):
return out, residual return out, residual
class RelativePositionBiases(nn.Module): class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-methods
""" """
T5-style relative positional embeddings to the attention logits. T5-style relative positional embeddings to the attention logits.
...@@ -848,7 +848,7 @@ class TransformerLayerType(Enum): ...@@ -848,7 +848,7 @@ class TransformerLayerType(Enum):
DECODER = "decoder" DECODER = "decoder"
class TransformerLayer(nn.Module): class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
r""" r"""
TransformerLayer is made up of a relative embedding, TransformerLayer is made up of a relative embedding,
an attention block and a feedforward network (MLP). an attention block and a feedforward network (MLP).
......
...@@ -35,6 +35,7 @@ class AttnMaskType(Enum): ...@@ -35,6 +35,7 @@ class AttnMaskType(Enum):
NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK
PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK
CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK
PADDING_CAUSAL_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK
class QKVLayout(Enum): class QKVLayout(Enum):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment