Unverified commit 69003969, authored by zlsh80826 and committed by GitHub

JAX bug fixes for the dot product attention (#236)



* Unfuse scale + softmax when a bias is present
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Work around a causal masking + no_bias bug and add unit tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix sharding of the optional argument (bias)
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Disable fused attention in JAX by default; enable it with NVTE_USE_FUSED_ATTN
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Make the execution plan caches thread local
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Rename dbeta to dbias for readability
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add scaled softmax with dropout test cases
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Update the NVTE_FUSED_ATTN variable name
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent 122de2cc
@@ -113,7 +113,7 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs)
                     reason="Fused attention kernel is not supported.")
 class TestSelfFusedAttnMax512():
-    def set_input(self, b, s, h, d, dtype, attn_mask_type, pad_ratio):
+    def set_input(self, b, s, h, d, dtype, attn_mask_type, pad_ratio, with_bias):
         key = jax.random.PRNGKey(0)
         subkeys = jax.random.split(key, 2)
@@ -125,7 +125,8 @@ class TestSelfFusedAttnMax512():
         min_val, max_val = -1, 1
         self.qkv = jax.random.uniform(subkeys[0], qkv_shape, dtype, min_val, max_val)
-        self.bias = jax.random.uniform(subkeys[1], bias_shape, dtype, min_val, max_val)
+        self.bias = jax.random.uniform(subkeys[1], bias_shape, dtype, min_val,
+                                       max_val) if with_bias else None
         self.q_token = jnp.concatenate((jnp.ones((b, self.valid_len)), jnp.zeros((b, pad_len))),
                                        axis=-1)
@@ -133,8 +134,8 @@ class TestSelfFusedAttnMax512():
         self.scaling_factor = 1. / math.sqrt(d)
         self.dropout_probability = 0.
-        self.dropout_rng = jax.random.PRNGKey(0)
-        self.attn_bias_type = AttnBiasType.POST_SCALE_BIAS
+        self.dropout_rng = jax.random.PRNGKey(0) if self.dropout_probability > 0 else None
+        self.attn_bias_type = AttnBiasType.NO_BIAS if self.bias is None else AttnBiasType.POST_SCALE_BIAS
         # deterministic = not is_training
         self.deterministic = False
@@ -143,9 +144,17 @@ class TestSelfFusedAttnMax512():
     @pytest.mark.parametrize('attn_mask_type',
                              [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
     @pytest.mark.parametrize('pad_ratio', PAD_RATIO)
-    def test_forward(self, b, s, h, d, dtype, attn_mask_type, pad_ratio):
-        self.set_input(b, s, h, d, dtype=dtype, attn_mask_type=attn_mask_type, pad_ratio=pad_ratio)
+    @pytest.mark.parametrize('with_bias', [True, False])
+    def test_forward(self, b, s, h, d, dtype, attn_mask_type, pad_ratio, with_bias):
+        self.set_input(b,
+                       s,
+                       h,
+                       d,
+                       dtype=dtype,
+                       attn_mask_type=attn_mask_type,
+                       pad_ratio=pad_ratio,
+                       with_bias=with_bias)
         primitive_out = customcall_self_fused_attn(self.qkv,
                                                    self.bias,
@@ -183,8 +192,16 @@ class TestSelfFusedAttnMax512():
                              [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
     @pytest.mark.parametrize('dtype', DTYPES)
     @pytest.mark.parametrize('pad_ratio', PAD_RATIO)
-    def test_forward_backward(self, b, s, h, d, dtype, attn_mask_type, pad_ratio):
-        self.set_input(b, s, h, d, dtype=dtype, attn_mask_type=attn_mask_type, pad_ratio=pad_ratio)
+    @pytest.mark.parametrize('with_bias', [True, False])
+    def test_forward_backward(self, b, s, h, d, dtype, attn_mask_type, pad_ratio, with_bias):
+        self.set_input(b,
+                       s,
+                       h,
+                       d,
+                       dtype=dtype,
+                       attn_mask_type=attn_mask_type,
+                       pad_ratio=pad_ratio,
+                       with_bias=with_bias)
         def grad_func(fused_attn_max_512_func, *args, **kwargs):
             # Gradient is small, use a gradient multiplier to amplify the graident
@@ -221,11 +238,11 @@ class TestSelfFusedAttnMax512():
                                                                   (0, 1)))
         primitive_out, (primitive_dqkv,
-                        primitive_dbeta) = jitted_primitive(self.qkv, self.bias, self.q_token,
+                        primitive_dbias) = jitted_primitive(self.qkv, self.bias, self.q_token,
                                                             self.kv_token, self.dropout_rng)
         reference_out, (reference_dqkv,
-                        reference_dbeta) = jitted_reference(self.qkv, self.bias, self.q_token,
+                        reference_dbias) = jitted_reference(self.qkv, self.bias, self.q_token,
                                                             self.kv_token, self.dropout_rng)
         np.testing.assert_allclose(jnp.asarray(primitive_out, np.float32),
@@ -261,20 +278,22 @@ class TestSelfFusedAttnMax512():
         # Padded part should be 0s
         assert jnp.allclose(invalid_primitive_dqkv, jnp.zeros_like(invalid_primitive_dqkv))
-        # dbeta valid part
-        np.testing.assert_allclose(
-            jnp.asarray(primitive_dbeta[:, :, :self.valid_len, :self.valid_len], np.float32),
-            jnp.asarray(reference_dbeta[:, :, :self.valid_len, :self.valid_len], np.float32),
-            rtol=1e-4,
-            atol=3e-5)
-        # dbeta padded part
-        np.testing.assert_allclose(
-            jnp.asarray(primitive_dbeta[:, :, self.valid_len:, self.valid_len:], np.float32),
-            jnp.asarray(reference_dbeta[:, :, self.valid_len:, self.valid_len:], np.float32))
-        assert jnp.allclose(primitive_dbeta[:, :, self.valid_len:, self.valid_len:],
-                            jnp.zeros_like(primitive_dbeta[:, :, self.valid_len:, self.valid_len:]))
+        if self.attn_bias_type != AttnBiasType.NO_BIAS:
+            # dbias valid part
+            np.testing.assert_allclose(
+                jnp.asarray(primitive_dbias[:, :, :self.valid_len, :self.valid_len], np.float32),
+                jnp.asarray(reference_dbias[:, :, :self.valid_len, :self.valid_len], np.float32),
+                rtol=1e-4,
+                atol=3e-5)
+            # dbias padded part
+            np.testing.assert_allclose(
+                jnp.asarray(primitive_dbias[:, :, self.valid_len:, self.valid_len:], np.float32),
+                jnp.asarray(reference_dbias[:, :, self.valid_len:, self.valid_len:], np.float32))
+            assert jnp.allclose(
+                primitive_dbias[:, :, self.valid_len:, self.valid_len:],
+                jnp.zeros_like(primitive_dbias[:, :, self.valid_len:, self.valid_len:]))
 @pytest.mark.skipif(not is_fused_attn_kernel_available(),
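For orientation, the sketch below is a minimal pure-JAX reference for the operation these tests validate: scaled dot-product attention over a packed QKV tensor with an optional post-scale bias and a padding mask. The function name, tensor layout, and masking convention are illustrative assumptions, not the repository's actual test helpers.

# A minimal reference, assuming a packed QKV layout of [batch, seqlen, 3, heads, dim],
# a post-scale bias of [1, heads, seqlen, seqlen], and a padding mask of
# [batch, 1, seqlen, seqlen] where 1 marks padded positions.
import math

import jax
import jax.numpy as jnp


def reference_self_attention(qkv, bias, mask):
    q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
    scale = 1.0 / math.sqrt(q.shape[-1])
    logits = jnp.einsum('bqhd,bkhd->bhqk', q, k) * scale    # scale the raw logits
    if bias is not None:
        logits = logits + bias                              # POST_SCALE_BIAS: added after scaling
    logits = jnp.where(mask > 0, -1e10, logits)             # suppress padded positions
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum('bhqk,bkhd->bqhd', weights, v)

The with_bias parametrization added above runs every case twice, once with bias=None (AttnBiasType.NO_BIAS), which is why the dbias checks are now guarded on the bias type.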
@@ -102,6 +102,12 @@ ATTRS = [{
     _KEY_OF_DROPOUT_RATE: 0.0,
     _KEY_OF_MLP_ACTIVATIONS: (('gelu', 'linear')),
     _KEY_OF_FUSE_MLP_WI: True
+}, {
+    _KEY_OF_SCALE_ATTN_LOGITS: True,
+    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
+    _KEY_OF_DROPOUT_RATE: 0.8,
+    _KEY_OF_MLP_ACTIVATIONS: (('gelu', 'linear')),
+    _KEY_OF_FUSE_MLP_WI: True
 }, {
     _KEY_OF_TRANSPOSE_BS: False,
     _KEY_OF_SCALE_ATTN_LOGITS: True,
@@ -327,7 +327,6 @@ static cudnn_frontend::Tensor createSoftmaxForward(
                                              // NOLINTNEXTLINE(runtime/references)
                                              std::vector<cudnn_frontend::Operation> &ops,
                                              cudnn_frontend::Tensor const &prevBlockOutputTensor) {
-
     int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
     int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
@@ -645,7 +644,7 @@ void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
                             mask_type, tensorType};
     using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
-    static CacheType fmha_fprop_cache;
+    static thread_local CacheType fmha_fprop_cache;
     bool enable_dropout = (dropout_probability != 0.0f);
@@ -668,7 +667,8 @@ void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
     createScale(b, h, s_q, s_kv, d, layout, tensorType, ops);
     // if bias, we need to memset the S buffer to correctly computate dbias
-    auto zero_s = (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS);
+    auto zero_s = (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) ||
+                  (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK);
     auto bmm1_output = createBMM1(b, h, s_q, s_kv, d, layout, tensorType, zero_s, ops);
     NVTE_CHECK(bias_type != NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS,
@@ -814,7 +814,7 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
                             layout, bias_type, mask_type, tensorType};
     using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
-    static CacheType fmha_bprop_cache;
+    static thread_local CacheType fmha_bprop_cache;
     auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) {
         auto it = cache.find(descriptor);
@@ -1016,7 +1016,7 @@ void fa_fwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
                             NVTE_Bias_Type::NVTE_NO_BIAS, NVTE_Mask_Type::NVTE_PADDING_MASK, tensorType};
     using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
-    static CacheType fa_fprop_cache;
+    static thread_local CacheType fa_fprop_cache;
     // Get plan from cache if cache is available, otherwise create one
     auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) {
@@ -1332,7 +1332,7 @@ void fa_bwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
                             NVTE_Bias_Type::NVTE_NO_BIAS, NVTE_Mask_Type::NVTE_PADDING_MASK, tensorType};
     using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
-    static CacheType fa_bprop_cache;
+    static thread_local CacheType fa_bprop_cache;
     // Get plan from cache if cache is available, otherwise create one
     auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) {
@@ -7,6 +7,7 @@ Wrapper module for Transformer related layers with FP8 support.
 import functools
 from enum import Enum
 from math import sqrt
+import os
 from typing import Any, Callable, Optional, Sequence, Tuple, Union
 import warnings
@@ -165,8 +166,17 @@ def core_attention(query: Array,
     else:
         attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)
+    # When a bias is present, the computation is performed as Softmax(attn_weights * scale + bias).
+    # In this case, the scale can not fused into the Softmax module.
+    if bias is not None:
+        attn_weights = attn_weights * scale_factor
+        fused_scale_factor = 1.
+    else:
+        # If no bias, the scale can be fused into Softmax module
+        fused_scale_factor = scale_factor
+
     attn_weights = Softmax(softmax_type=softmax_type,
-                           scale_factor=scale_factor,
+                           scale_factor=fused_scale_factor,
                            sharding_type=softmax_sharding_type)(attn_weights, mask, bias)
     if not deterministic and dropout_rate > 0.:
@@ -360,12 +370,13 @@ class MultiHeadAttention(nn.Module):
         q_seqlen = inputs_q.shape[0] if self.transpose_batch_sequence else inputs_q.shape[1]
         kv_seqlen = inputs_kv.shape[0] if self.transpose_batch_sequence else inputs_kv.shape[1]
         fused_attn_supported_seqlen = [128, 256, 384, 512]
+        enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))
         use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \
             self.dropout_rate == 0 and canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \
             q_seqlen in fused_attn_supported_seqlen and kv_seqlen in fused_attn_supported_seqlen \
-            and is_fused_attn_kernel_available()
+            and is_fused_attn_kernel_available() and enable_fused_attn
-        if not use_fused_attn:
+        if enable_fused_attn and not use_fused_attn:
             reason = ""
             if decode:
                 reason += f"decode=False is required but got {decode}, "
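The comment added in core_attention is the heart of the first fix: with a post-scale bias the intended result is Softmax(logits * scale + bias), so the scale can no longer be folded into the Softmax module, which would effectively scale the bias as well. A small numerical check (illustrative only, using plain jax.nn.softmax rather than the Softmax module in this file) makes the difference concrete:

import jax
import jax.numpy as jnp

logits = jnp.array([[1.0, 2.0, 3.0]])
bias = jnp.array([[0.5, -0.5, 0.0]])
scale = 0.125

unfused = jax.nn.softmax(logits * scale + bias)    # scale applied to the logits only (intended)
fused = jax.nn.softmax((logits + bias) * scale)    # folding the scale into softmax also scales the bias
print(jnp.allclose(unfused, fused))                # False: the two orderings disagree

The same hunk also makes the fused attention path opt-in: it is used only when the NVTE_FUSED_ATTN environment variable is set to a non-zero value, and the default of "0" keeps it disabled.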
@@ -386,7 +386,7 @@ class FusedAttnShardingMetaGenerator(ShardingMetaGenerator):
         for input_shape, dp_dim, tp_dim in zip(input_shapes, input_dp_dims, input_tp_dims):
             in_axis = {}
-            if dp_dim is not None:
+            if dp_dim is not None and input_shape is not None:
                 in_axis[dp_dim] = dp_axis_name
                 assert input_shape[dp_dim] % dp_size == 0, \
                     f"The dimension of batch in input_shape should be a multiple of " \
@@ -398,7 +398,7 @@ class FusedAttnShardingMetaGenerator(ShardingMetaGenerator):
             if tp_dim is not None and tp_dim >= dp_dim:
                 tp_dim = tp_dim + 1
-            if tp_dim is not None:
+            if tp_dim is not None and input_shape is not None:
                 in_axis[tp_dim] = tp_axis_name
                 assert input_shape[tp_dim] % tp_size == 0, \
                     f"The dimension of tensor parallel in input_shape should be a multiple of " \
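The sharding fix guards against optional inputs: the bias may be passed as None, in which case its entry in input_shapes is also None and must be skipped when building the sharding axes. The sketch below is a simplified stand-in for FusedAttnShardingMetaGenerator (the helper name and shapes are assumptions) showing why the extra input_shape is not None check matters:

from typing import Optional, Sequence, Tuple


def build_in_axes(input_shapes: Sequence[Optional[Tuple[int, ...]]],
                  input_dp_dims: Sequence[Optional[int]],
                  dp_axis_name: str = 'data'):
    axes = []
    for input_shape, dp_dim in zip(input_shapes, input_dp_dims):
        in_axis = {}
        # Optional inputs (e.g. bias=None) have no shape; indexing None would raise.
        if dp_dim is not None and input_shape is not None:
            in_axis[dp_dim] = dp_axis_name
        axes.append(in_axis)
    return axes


# The second input (the optional bias) is absent here; without the None check this would fail.
print(build_in_axes([(32, 512, 3, 16, 64), None], [0, 0]))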