Unverified Commit bfaec644 authored by zlsh80826, committed by GitHub

[C/JAX] Support more mask types for the arbitrary seqlen kernels and minor changes of JAX bias (#469)

* Move bias to float32
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Enable varlen
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Increase neg infinity abs values
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Enable varlen tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove unnecessary code
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix lint
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Support variable sequence length after cuDNN 8.9.6
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Use unique_ptr instead of shared_ptr
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add a new mask type: PADDING_CAUSAL_MASK
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Support flash padding mask after 8.9.6
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Enhance the Max512 handling for causal masking and add the related tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Update the fused attn support lists
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove padding_aware from the caching
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix libtransformer.so issue
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Reduce the pad ratio tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix a bug with cuDNN 8.9.5
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Release backend resource after the module level unit test
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Clean the jax live arrays before running the unit tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix too-few-public-methods lint
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent 64a3d1d5
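In short, this commit teaches the arbitrary-seqlen (flash) fused-attention backend to honor padding, adds the combined `PADDING_CAUSAL_MASK`, and moves JAX-side bias parameters to float32. Whether a given configuration is served can be probed with `is_fused_attn_kernel_available`; the sketch below reuses the positional argument order from the updated tests in this diff, with illustrative shapes.

```python
import jax.numpy as jnp
from transformer_engine.jax.fused_attn import (AttnBiasType, AttnMaskType,
                                               QKVLayout,
                                               is_fused_attn_kernel_available)

# Probe the backend: with cuDNN >= 8.9.6 the arbitrary-seqlen kernels now
# accept PADDING_MASK and the new PADDING_CAUSAL_MASK, not just CAUSAL_MASK.
ok = is_fused_attn_kernel_available(
    jnp.bfloat16, jnp.bfloat16,          # q_dtype, kv_dtype
    QKVLayout.BS3HD,                     # packed self-attention layout
    AttnBiasType.NO_BIAS,
    AttnMaskType.PADDING_CAUSAL_MASK,    # new mask type added in this commit
    0.0,                                 # dropout_probability
    2048, 2048,                          # max_seqlen_q, max_seqlen_kv (> 512)
    64)                                  # head_dim
print("padded causal supported:", ok)
```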
@@ -37,6 +37,16 @@ DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, reason = is_fp8_available()
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
"""
yield
for arr in jax.live_arrays():
arr.delete()
class TestFP8Dot:
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
...
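The autouse fixture above tears down every device buffer a test leaves behind, so GPU memory does not accumulate across the module (the "Clean the jax live arrays" bullet). A small sketch of the two JAX calls it relies on:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1024, 1024))                     # allocates a device buffer
assert any(a is x for a in jax.live_arrays())  # x is tracked as live
x.delete()                                     # frees the buffer immediately,
                                               # without waiting for the GC
```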
@@ -3,6 +3,7 @@
# See LICENSE for license information.
import pytest
import jax
import jax.numpy as jnp
from jax.core import ShapedArray
@@ -31,6 +32,16 @@ DTYPE = [DType.kFloat32, DType.kFloat16, DType.kBFloat16]
TRANSPOSE = [True, False]
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
"""
yield
for arr in jax.live_arrays():
arr.delete()
class TestGEMMShapeInfer:
@staticmethod
...
@@ -21,12 +21,22 @@ from jax import value_and_grad, jit
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn
from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
from transformer_engine_jax import get_device_compute_capability
from transformer_engine_jax import get_device_compute_capability # pylint: disable=wrong-import-order
# Type annotations
Array = jnp.ndarray
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
"""
yield
for arr in jax.live_arrays():
arr.delete()
class Backend(Enum):
"""
Fused attn backend.
@@ -52,6 +62,13 @@ CROSS_CASES = [(32, 128, 512, 16, 64)]
DTYPES = [jnp.bfloat16, jnp.float16]
def is_causal_mask(mask: AttnMaskType):
"""
Check if the mask is a causal mask
"""
return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]
def make_decoder_mask(tokens: Array) -> Array:
"""
Create padded causal mask
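`make_decoder_mask` is cut off above; as a sketch (using flax's mask helpers, which this test file already relies on), a padded causal mask is the AND of a causal mask and a padding mask built from the token ids:

```python
import jax.numpy as jnp
from flax.linen import combine_masks, make_attention_mask, make_causal_mask

def make_decoder_mask_sketch(tokens: jnp.ndarray) -> jnp.ndarray:
    """Padded causal mask: query i may attend to key j only when j <= i
    and both positions hold real (non-zero) tokens."""
    causal_mask = make_causal_mask(tokens)                       # [b, 1, s, s]
    padding_mask = make_attention_mask(tokens > 0, tokens > 0)   # [b, 1, s, s]
    return combine_masks(causal_mask, padding_mask)
```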
@@ -66,7 +83,7 @@ def jax_self_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
Self attention with JAX native implementation
"""
attn_mask_type = kwargs['attn_mask_type']
if attn_mask_type == AttnMaskType.CAUSAL_MASK:
if is_causal_mask(attn_mask_type):
mask = make_decoder_mask(q_token)
else:
mask = make_attention_mask(q_token > 0, kv_token > 0)
@@ -84,8 +101,8 @@ def jax_self_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
deterministic=not kwargs['is_training'],
dropout_rate=kwargs['dropout_probability'],
dropout_rng=dropout_rng,
dtype=qkv.dtype)
return output
dtype=jnp.float32)
return output.astype(qkv.dtype)
def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
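The `dtype=jnp.float32` change above makes the reference attention (a flax-style `dot_product_attention`, judging by the keyword set) run its softmax in float32 and cast the result back, so bf16/fp16 rounding in the reference does not widen the comparison gap; the commit's "Increase neg infinity abs values" bullet concerns the masked-fill constant in the same computation. A plain-jnp sketch of the pattern:

```python
import jax
import jax.numpy as jnp

def reference_attention(q, k, v, mask):
    # Accumulate logits and softmax in float32 even for bf16/fp16 inputs,
    # then cast back to the input dtype (the dtype=jnp.float32 change above).
    logits = jnp.einsum('bqhd,bkhd->bhqk',
                        q.astype(jnp.float32), k.astype(jnp.float32))
    neg_inf = jnp.finfo(jnp.float32).min   # large-magnitude fill for masked slots
    logits = jnp.where(mask, logits, neg_inf)
    probs = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum('bhqk,bkhd->bqhd', probs, v.astype(jnp.float32))
    return out.astype(q.dtype)
```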
@@ -95,7 +112,7 @@ def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
assert q.dtype == kv.dtype
attn_mask_type = kwargs['attn_mask_type']
if attn_mask_type == AttnMaskType.CAUSAL_MASK:
if is_causal_mask(attn_mask_type):
raise NotImplementedError
mask = make_attention_mask(q_token > 0, kv_token > 0)
@@ -112,15 +129,16 @@ def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
deterministic=not kwargs['is_training'],
dropout_rate=kwargs['dropout_probability'],
dropout_rng=dropout_rng,
dtype=q.dtype)
return output
dtype=jnp.float32)
return output.astype(q.dtype)
def customcall_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
"""
Self fused attention
"""
if kwargs['attn_mask_type'] == AttnMaskType.CAUSAL_MASK:
attn_mask_type = kwargs['attn_mask_type']
if is_causal_mask(attn_mask_type):
mask = make_decoder_mask(q_token)
else:
mask = make_attention_mask(q_token > 0, kv_token > 0)
@@ -137,7 +155,8 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs)
"""
assert q.dtype == kv.dtype
if kwargs['attn_mask_type'] == AttnMaskType.CAUSAL_MASK:
attn_mask_type = kwargs['attn_mask_type']
if is_causal_mask(attn_mask_type):
raise NotImplementedError
mask = make_attention_mask(q_token > 0, kv_token > 0)
@@ -149,32 +168,28 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs)
@pytest.mark.parametrize('b, s, h, d', SELF_CASES)
@pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS])
@pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
@pytest.mark.parametrize('attn_mask_type', [
AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK,
AttnMaskType.PADDING_CAUSAL_MASK
])
@pytest.mark.parametrize('dropout_probability', [0., 0.1])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True, False])
@pytest.mark.parametrize('pad_ratio', [0, 0.3])
class TestSelfFusedAttn():
"""Tests for transformer_engine.jax.fused_attn.self_fused_attn"""
@staticmethod
def _check_inputs(s, *, attn_bias_type, attn_mask_type, backend, dropout_probability, dtype,
head_dim, pad_ratio):
if (s > 512 or backend == Backend.Arbitrary) and pad_ratio != 0:
pytest.skip("Arbitrary seqlen backend hasn't support padded input.")
head_dim):
assert isinstance(backend, Backend)
if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type,
attn_mask_type, dropout_probability, s, s, head_dim):
pytest.skip("Unsupported inputs combination or device compute capability.")
compute_capability = get_device_compute_capability(0)
if (backend == Backend.Max512
and not (compute_capability == 80 or compute_capability >= 90)):
pytest.skip("Unsupported compute capability for "
"fused attention with <=512 sequence length")
def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend,
dropout_probability, dtype, is_training, pad_ratio):
dropout_probability, dtype, is_training):
"""Setup the test inputs"""
self.__class__._check_inputs(s,
attn_bias_type=attn_bias_type,
@@ -182,8 +197,13 @@ class TestSelfFusedAttn():
backend=backend,
dropout_probability=dropout_probability,
dtype=dtype,
head_dim=d,
pad_ratio=pad_ratio)
head_dim=d)
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
pad_ratio = 0.0
else:
pad_ratio = 0.3
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
@@ -212,7 +232,7 @@ class TestSelfFusedAttn():
self.is_training = is_training
def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend, dropout_probability,
dtype, is_training, pad_ratio):
dtype, is_training):
"""
Test forward without using JIT
"""
@@ -225,8 +245,7 @@ class TestSelfFusedAttn():
backend=backend,
dropout_probability=dropout_probability,
dtype=dtype,
is_training=is_training,
pad_ratio=pad_ratio)
is_training=is_training)
primitive_out = customcall_self_fused_attn(self.qkv,
self.bias,
@@ -265,7 +284,7 @@ class TestSelfFusedAttn():
jnp.zeros_like(pri_invalid, jnp.float32))
def test_forward_backward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend,
dropout_probability, dtype, is_training, pad_ratio):
dropout_probability, dtype, is_training):
"""
Test forward, backward, and autodiff by jax.value_and_grad
"""
@@ -281,13 +300,12 @@ class TestSelfFusedAttn():
backend=backend,
dropout_probability=dropout_probability,
dtype=dtype,
is_training=is_training,
pad_ratio=pad_ratio)
is_training=is_training)
def grad_func(fused_attn_func, *args, **kwargs):
# Gradient is small, use a gradient multiplier to amplify the gradient
gradient_multiplier = 1000 if dtype == jnp.bfloat16 else 10000
if attn_mask_type == AttnMaskType.CAUSAL_MASK:
if is_causal_mask(attn_mask_type):
gradient_multiplier = gradient_multiplier / 10
# Keep only valid result for the gradient
# fused_attn output has shape (b, s, h, d)
@@ -333,15 +351,15 @@ class TestSelfFusedAttn():
rtol=1e-4,
atol=1e-5)
valid_primitive_dqkv, invalid_primitive_dqkv = jnp.split(primitive_dqkv, (self.valid_len,),
axis=1)
valid_reference_dqkv, invalid_reference_dqkv = jnp.split(reference_dqkv, (self.valid_len,),
axis=1)
valid_primitive_dqkv, invalid_primitive_dqkv = \
jnp.split(primitive_dqkv.astype(jnp.float32), (self.valid_len,), axis=1)
valid_reference_dqkv, invalid_reference_dqkv = \
jnp.split(reference_dqkv.astype(jnp.float32), (self.valid_len,), axis=1)
valid_primitive_dq, valid_primitive_dk, valid_primitive_dv = jnp.split(
valid_primitive_dqkv.astype(jnp.float32), 3, axis=2)
valid_reference_dq, valid_reference_dk, valid_reference_dv = jnp.split(
valid_reference_dqkv.astype(jnp.float32), 3, axis=2)
valid_primitive_dq, valid_primitive_dk, valid_primitive_dv = \
jnp.split(valid_primitive_dqkv, 3, axis=2)
valid_reference_dq, valid_reference_dk, valid_reference_dv = \
jnp.split(valid_reference_dqkv, 3, axis=2)
np.testing.assert_allclose(valid_primitive_dq, valid_reference_dq, rtol=1e-4, atol=1e-5)
np.testing.assert_allclose(valid_primitive_dk, valid_reference_dk, rtol=1e-4, atol=1e-5)
@@ -482,9 +500,7 @@ class TestCrossFusedAttn():
def grad_func(fused_attn_func, *args, **kwargs):
# Gradient is small, use a gradient multiplier to amplify the gradient
gradient_multiplier = 10000
if attn_mask_type == AttnMaskType.CAUSAL_MASK:
gradient_multiplier = gradient_multiplier / 10
gradient_multiplier = 1e4
# Keep only valid result for the gradient
# fused_attn output has shape (b, s_q, h, d)
valid_fused_attn_ret, _ = jnp.split(fused_attn_func(*args, **kwargs),
...
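For cross attention the multiplier becomes a flat `1e4`, since the causal branch raises `NotImplementedError` anyway. The point of the multiplier: the mean-of-outputs loss yields gradients small enough to drown in bf16/fp16 noise, so the loss is scaled before differentiation to sit well above the `assert_allclose` tolerances. A sketch (the helper name and `valid_len` handling are illustrative):

```python
import jax.numpy as jnp

def amplified_loss(attn_fn, *args, valid_len, gradient_multiplier=1e4, **kwargs):
    out = attn_fn(*args, **kwargs)                   # (b, s, h, d)
    valid, _ = jnp.split(out, (valid_len,), axis=1)  # keep only non-padded rows
    # Scale the loss so finite-precision gradients stay well above the
    # comparison tolerances used by np.testing.assert_allclose.
    return jnp.mean(valid, dtype=jnp.float32) * gradient_multiplier
```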
@@ -19,6 +19,16 @@ from utils import EncoderLayer as RefEncoderLayer
is_fp8_supported, reason = is_fp8_available()
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
"""
yield
for arr in jax.live_arrays():
arr.delete()
def loss_fn(diff_xs, no_diff_xs, params, others, model, rngs):
output = model.apply({"params": params, **others}, *diff_xs, *no_diff_xs, rngs=rngs)
return jnp.mean(output)
...
@@ -38,6 +38,16 @@ ENABLE_FP8 = [False, True]
FP8_FORMATS = [Format.E4M3, Format.HYBRID]
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
"""
yield
for arr in jax.live_arrays():
arr.delete()
def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
for key in ref_fd:
assert key in test_fd, \
...
@@ -87,6 +87,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
const int sm_arch_ = cuda::sm_arch(device_id);
NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type.");
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
auto cudnn_runtime_version = cudnnGetVersion();
if (((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2))
&& (sm_arch_ >= 90)
&& (max_seqlen_q == max_seqlen_kv)
@@ -111,6 +112,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
&& ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
|| (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS))
&& ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))
&& ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
@@ -131,7 +133,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
&& (max_seqlen_q == max_seqlen_kv)
&& ((head_dim == 64) || (head_dim == 128))
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
&& (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
&& ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
|| ((cudnn_runtime_version >= 8906) &&
(attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)))
&& ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD))) {
...
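The version gate above, restated as executable pseudocode (Python for readability; the authoritative check is the C expression in the diff):

```python
def arbitrary_seqlen_mask_supported(cudnn_runtime_version: int,
                                    mask_type: str) -> bool:
    # Before cuDNN 8.9.6 the arbitrary-seqlen backend only handled a pure
    # causal mask; from 8.9.6 on it also takes padding and padding+causal.
    if cudnn_runtime_version < 8906:
        return mask_type == "CAUSAL_MASK"
    return mask_type in ("CAUSAL_MASK", "PADDING_CAUSAL_MASK", "PADDING_MASK")
```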
@@ -298,7 +298,8 @@ static cudnn_frontend::Tensor createMask(int64_t b, int64_t h, int64_t s_q, int6
/////////////////// Apply the mask //////////////////////////
auto maskTensor = (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
auto maskTensor = (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)
? std::move(causalMaskTensor)
: std::move(paddingMaskTensor);
@@ -314,7 +315,8 @@ static cudnn_frontend::Tensor createMask(int64_t b, int64_t h, int64_t s_q, int6
ops.push_back(std::move(lessThanRow_op));
ops.push_back(std::move(lessThanCol_op));
ops.push_back(std::move(paddingMaskAnd_op));
if (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) {
if (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) {
ops.push_back(std::move(rowGreaterCol_op));
ops.push_back(std::move(causalMaskAnd_op));
}
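`createMask` now routes `NVTE_PADDING_CAUSAL_MASK` through the same causal tensor and graph ops as `NVTE_CAUSAL_MASK`, layered on top of the padding AND. The composed predicate, sketched with boolean arrays standing in for the cuDNN graph ops:

```python
import jax.numpy as jnp

def composed_mask(q_valid, kv_valid, causal: bool):
    # Padding: both the query row and the key column must be real tokens
    # (the lessThanRow/lessThanCol + paddingMaskAnd ops in createMask).
    mask = q_valid[:, None] & kv_valid[None, :]          # [s_q, s_kv]
    if causal:
        # Causal: additionally require row >= col (the rowGreaterCol +
        # causalMaskAnd ops, now also applied for NVTE_PADDING_CAUSAL_MASK).
        s_q, s_kv = mask.shape
        mask &= jnp.tril(jnp.ones((s_q, s_kv), dtype=bool))
    return mask
```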
@@ -680,7 +682,8 @@ void fused_attn_max_512_fwd_impl(
// WAR: causal_mask without bias needs memset the S buffer
// inference mode doesn't need the S auxiliary
auto zero_s = (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) ||
(mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) && is_training;
(mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) && is_training;
std::shared_ptr<cudnn_frontend::Tensor> maskInput;
auto bmm1_output = createBMM1(b, h, s_q, s_kv, d, layout, tensorType, zero_s, ops);
...
@@ -146,6 +146,8 @@ enum NVTE_Mask_Type {
NVTE_PADDING_MASK = 1,
/*! Causal attention mask */
NVTE_CAUSAL_MASK = 2,
/*! Padding and causal attention mask */
NVTE_PADDING_CAUSAL_MASK = 3,
};
/*! \enum NVTE_Fused_Attn_Backend
@@ -209,10 +211,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
*
* Support Matrix:
\verbatim
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | CAUSAL_MASK | Yes | > 512 | 64, 128 |
| 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | PADDING/CAUSAL/PADDING_CAUSAL | Yes | > 512 | 64, 128 |
| 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] QKV The QKV tensor in packed format,
@@ -254,10 +256,10 @@ void nvte_fused_attn_fwd_qkvpacked(
*
* Support Matrix:
\verbatim
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | CAUSAL_MASK | Yes | > 512 | 64, 128 |
| 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | PADDING/CAUSAL/PADDING_CAUSAL | Yes | > 512 | 64, 128 |
| 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] QKV The QKV tensor in packed format,
@@ -308,8 +310,8 @@ void nvte_fused_attn_bwd_qkvpacked(
*
* Support Matrix:
\verbatim
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | KV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | KV_INTERLEAVED | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
@@ -356,8 +358,8 @@ void nvte_fused_attn_fwd_kvpacked(
*
* Support Matrix:
\verbatim
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | KV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | KV_INTERLEAVED | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
@@ -415,10 +417,10 @@ void nvte_fused_attn_bwd_kvpacked(
*
* Support Matrix:
\verbatim
| backend | precision | qkv format | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | PADDING/CAUSAL_MASK | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | CAUSAL_MASK | Yes | > 512 | 64, 128 |
| 2 | FP8 | THD | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
| backend | precision | qkv format | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | PADDING/CAUSAL/PADDING_CAUSAL | Yes | > 512 | 64, 128 |
| 2 | FP8 | THD | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] Q The Q tensor.
@@ -467,10 +469,10 @@ void nvte_fused_attn_fwd(
*
* Support Matrix:
\verbatim
| backend | precision | qkv format | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | PADDING/CAUSAL_MASK | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | CAUSAL_MASK | Yes | > 512 | 64, 128 |
| 2 | FP8 | THD | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
| backend | precision | qkv format | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | NO_MASK/PADDING/CAUSAL/PADDING_CAUSAL | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | SBHD, BSHD | NO/POST_SCALE_BIAS | PADDING/CAUSAL/PADDING_CAUSAL | Yes | > 512 | 64, 128 |
| 2 | FP8 | THD | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] Q The Q tensor.
...
@@ -83,7 +83,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
pybind11::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type", pybind11::module_local())
.value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
.value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
.value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK);
.value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK)
.value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK);
pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local())
.value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD)
...
@@ -102,7 +102,7 @@ def _combine_biases(*masks: List[Array]):
return mask
class Softmax(nn.Module):
class Softmax(nn.Module): # pylint: disable=too-few-public-methods
r"""
Applies softmax over a mini-batch of inputs.
The input's shape should be [batch, heads, q_seqlen, k_seqlen].
@@ -176,7 +176,7 @@ class Softmax(nn.Module):
return outputs
class LayerNorm(nn.Module):
class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
r"""
Applies layer normalization over a mini-batch of inputs.
There are two types of normalization supported by this module,
@@ -431,8 +431,9 @@ class DenseGeneral(TransformerEngineBase):
bias = nn_partitioning.param_with_axes('bias',
self.bias_init,
features,
self.dtype,
jnp.float32,
axes=self.bias_axes)
bias = bias.astype(self.dtype)
else:
bias = None
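This is the "Move bias to float32" change: bias parameters are created in float32 and cast to the module's compute dtype only at the point of use, which keeps the master copy and its updates in full precision. A minimal Flax sketch of the pattern (a hypothetical `BiasedDense`, not TE's `DenseGeneral`; plain `self.param` stands in for `param_with_axes`):

```python
import jax.numpy as jnp
from flax import linen as nn

class BiasedDense(nn.Module):   # hypothetical, for illustration only
    features: int
    dtype: jnp.dtype = jnp.bfloat16

    @nn.compact
    def __call__(self, x):
        kernel = self.param('kernel', nn.initializers.lecun_normal(),
                            (x.shape[-1], self.features), jnp.float32)
        # Master copy of the bias stays in float32...
        bias = self.param('bias', nn.initializers.zeros,
                          (self.features,), jnp.float32)
        y = x @ kernel.astype(self.dtype)
        # ...and is cast to the compute dtype only where it is consumed.
        return y + bias.astype(self.dtype)
```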
@@ -656,8 +657,9 @@ class LayerNormDenseGeneral(TransformerEngineBase):
bias = nn_partitioning.param_with_axes('bias',
self.bias_init,
features,
self.dtype,
jnp.float32,
axes=self.bias_axes)
bias = bias.astype(self.dtype)
if bias is not None:
bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
@@ -969,8 +971,9 @@ class LayerNormMLP(TransformerEngineBase):
bias = nn_partitioning.param_with_axes('wi_bias',
self.bias_init,
intermediate_dim,
self.dtype,
jnp.float32,
axes=self.bias_axes_1)
bias = bias.astype(self.dtype)
bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape
x += jnp.reshape(bias, bias_shape)
@@ -1029,8 +1032,9 @@ class LayerNormMLP(TransformerEngineBase):
if self.use_bias:
bias = nn_partitioning.param_with_axes('wo_bias',
self.bias_init, (hidden_size,),
self.dtype,
jnp.float32,
axes=self.bias_axes_2)
bias = bias.astype(self.dtype)
out += jnp.reshape(bias, (1,) * (out.ndim - 1) + (-1,))
return out, ln_output # Output, layer_norm_output
@@ -247,7 +247,7 @@ def core_attention(query: Array,
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
class MultiHeadAttention(nn.Module):
class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
r"""
Multi-head Attention (MHA), including Query,
Key, Value and Output projection.
@@ -422,7 +422,7 @@ class MultiHeadAttention(nn.Module):
Convert the string to AttnMaskType
"""
if attn_mask_type == 'causal':
return AttnMaskType.CAUSAL_MASK
return AttnMaskType.PADDING_CAUSAL_MASK
if attn_mask_type == 'padding':
return AttnMaskType.PADDING_MASK
raise ValueError(f"Unsupported {attn_mask_type=}, "
@@ -741,7 +741,7 @@ class MultiHeadAttention(nn.Module):
return out, residual
class RelativePositionBiases(nn.Module):
class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-methods
"""
T5-style relative positional embeddings to the attention logits.
@@ -848,7 +848,7 @@ class TransformerLayerType(Enum):
DECODER = "decoder"
class TransformerLayer(nn.Module):
class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
r"""
TransformerLayer is made up of a relative embedding,
an attention block and a feedforward network (MLP).
......
@@ -35,6 +35,7 @@ class AttnMaskType(Enum):
NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK
PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK
CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK
PADDING_CAUSAL_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK
class QKVLayout(Enum):
...
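With the enum extended, all four mask types are visible from Python; a quick check (member names from this diff, the printed values being the pybind11-exposed NVTE enum constants):

```python
from transformer_engine.jax.fused_attn import AttnMaskType

# Each member wraps the matching NVTE_Mask_Type value exposed via pybind11.
for mask in AttnMaskType:
    print(mask.name, '->', mask.value)
# Expected members after this commit:
# NO_MASK, PADDING_MASK, CAUSAL_MASK, PADDING_CAUSAL_MASK
```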