Unverified commit b8eea8aa, authored by cyanguwa, committed by GitHub
[C/PyTorch/Jax] Add support for more bias shapes (#677)

* added support for arbitrary bias shapes for fused_attn
Signed-off-by: Alp Dener <adener@nvidia.com>

* Fix linting
Signed-off-by: Alp Dener <adener@nvidia.com>

* Add b1ss/bhss/11ss bias shapes when not requiring dBias
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add bias_b/h to plan cache
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fixed compile errors after PR653 merge
Signed-off-by: Alp Dener <adener@nvidia.com>

* updated JAX unittests for new bias shapes
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed mismatched mask type checking
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected skip condition
Signed-off-by: Alp Dener <adener@nvidia.com>

* fix selection logic for A100s
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* corrected skip checks for bias shapes
Signed-off-by: Alp Dener <adener@nvidia.com>

* resolved test issues but neg_inf with float16 is still problematic with JAX
Signed-off-by: Alp Dener <adener@nvidia.com>

* new bias shapes passing TE JAX CI for seqlen <= 512, seq_q == seq_kv and h_q == h_kv conditions
Signed-off-by: Alp Dener <adener@nvidia.com>

* TE/JAX fused attn tests for new bias shapes passing with neg_inf=-2**27 for Bfloat16 and -2**15 for Float16
Signed-off-by: Alp Dener <adener@nvidia.com>

* code style fixes and test parameter ID cleanup
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed incorrect skip condition for backward fused attn test
Signed-off-by: Alp Dener <adener@nvidia.com>

---------

Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: Alp Dener <adener@nvidia.com>
parent 04040957
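The diff below touches the JAX tests, the PyTorch tests, the cuDNN fused-attention backend, and the JAX C++ extension headers. For orientation, here is a minimal, purely illustrative sketch (not part of the diff) of what the change enables: a post-scale bias may now have shape [1, h, s_q, s_kv], [b, 1, s_q, s_kv], [b, h, s_q, s_kv], or [1, 1, s_q, s_kv], and it is added to the scaled attention logits by broadcasting. dBias is still only computed for the [1, h, s_q, s_kv] layout.

```python
# Illustrative only: how a post-scale bias of any of the newly supported
# shapes broadcasts against attention logits of shape [b, h, s_q, s_kv].
import jax.numpy as jnp

b, h, s_q, s_kv, scale = 2, 4, 8, 8, 0.125
logits = jnp.zeros((b, h, s_q, s_kv))
for bias_shape in [(1, h, s_q, s_kv),   # 1HSS: the only shape with dBias support
                   (b, 1, s_q, s_kv),   # B1SS
                   (b, h, s_q, s_kv),   # BHSS
                   (1, 1, s_q, s_kv)]:  # 11SS
    bias = jnp.ones(bias_shape)
    out = scale * logits + bias          # post-scale bias: added after the scaled Q*K^T
    assert out.shape == (b, h, s_q, s_kv)
```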
@@ -2,14 +2,15 @@
 #
 # See LICENSE for license information.
 """Tests for fused attention"""
+import sys
+from enum import Enum
 from dataclasses import dataclass
 from functools import partial
 from math import sqrt

 import jax
 import jax.numpy as jnp
-import numpy as np
 import pytest
 from flax.linen import combine_masks
@@ -21,7 +22,11 @@ from jax.typing import ArrayLike, DTypeLike
 from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
 from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn, fused_attn
-from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
+from transformer_engine.jax.cpp_extensions import FusedAttnHelper
+from transformer_engine_jax import NVTE_Fused_Attn_Backend
+from utils import assert_allclose

 @pytest.fixture(autouse=True, scope='function')
@@ -38,7 +43,7 @@ def clear_live_arrays():
 def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: ArrayLike,
                                   bias: ArrayLike, mask: ArrayLike, deterministic: bool,
-                                  dropout_rate: float, dropout_rng: ArrayLike,
+                                  scale_factor: float, dropout_rate: float, dropout_rng: ArrayLike,
                                   dtype: DTypeLike) -> Array:
     """
     Similar to flax.linen.dot_product_attention but with GQA support
@@ -46,21 +51,21 @@ def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: Array
     query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
     dtype = query.dtype
-    depth = query.shape[-1]
-    query = query / jnp.sqrt(depth).astype(dtype)
     b, s_q, h_q, d = query.shape
-    _, _, h_kv, _ = key.shape
+    _, s_kv, h_kv, _ = key.shape
     assert (h_q % h_kv == 0) and (h_q >= h_kv)
     num_groups = h_q // h_kv
     grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
     # logits with shape (b, h_kv, num_groups, s_q, s_kv)
-    logits = jnp.einsum('...qhgd,...khd->...hgqk', grouped_query, key)
+    logits = scale_factor * jnp.einsum('...qhgd,...khd->...hgqk', grouped_query, key)

     if bias is not None:
-        if bias.ndim != logits.ndim:
-            bias = bias.reshape((1, *logits.shape[1:]))
+        # reshape logits without groups
+        logits = logits.reshape((b, h_kv * num_groups, s_q, s_kv))
+        # apply post-scale bias
         logits = logits + bias
+        # reshape logits back to original
+        logits = logits.reshape((b, h_kv, num_groups, s_q, s_kv))

     if mask is not None:
         if mask.ndim != logits.ndim:
@@ -114,6 +119,7 @@ def jax_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
         bias=bias,
         mask=mask,
         deterministic=not kwargs['is_training'],
+        scale_factor=kwargs['scaling_factor'],
         dropout_rate=kwargs['dropout_probability'],
         dropout_rng=dropout_rng,
         dtype=jnp.float32)
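The reference implementation now receives the softmax scale as an explicit `scale_factor` instead of pre-dividing the query by sqrt(depth). A self-contained sketch of the grouped-query logits computation it performs (toy shapes, illustrative only, not part of the diff):

```python
# Minimal, self-contained sketch of the reference behaviour after this change:
# the 1/sqrt(d) factor is applied to the logits as an explicit scale_factor.
import jax
import jax.numpy as jnp

b, s_q, s_kv, h_q, h_kv, d = 2, 8, 8, 4, 2, 16          # GQA: h_q is a multiple of h_kv
key = jax.random.PRNGKey(0)
q = jax.random.normal(key, (b, s_q, h_q, d), jnp.float32)
k = jax.random.normal(key, (b, s_kv, h_kv, d), jnp.float32)

scale_factor = 1.0 / jnp.sqrt(d)
num_groups = h_q // h_kv
grouped_q = q.reshape(b, s_q, h_kv, num_groups, d)
# logits with shape (b, h_kv, num_groups, s_q, s_kv), scaled post-matmul
logits = scale_factor * jnp.einsum('bqhgd,bkhd->bhgqk', grouped_q, k)
print(logits.shape)   # (2, 2, 2, 8, 8)
```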
@@ -149,6 +155,13 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng
                                **kwargs).astype(query.dtype)

+class BiasShape(Enum):
+    BIAS_1HSS = '1HSS'
+    BIAS_B1SS = 'B1SS'
+    BIAS_BHSS = 'BHSS'
+    BIAS_11SS = '11SS'
+
 @dataclass
 class FusedAttnRunner:
     """
@@ -166,6 +179,7 @@ class FusedAttnRunner:
     dtype: DTypeLike
     is_training: bool
     qkv_layout: QKVLayout
+    bias_shape: BiasShape

     def _check_configs(self):
         if self.qkv_layout == QKVLayout.BS3HD and self.num_heads_q != self.num_heads_kv:
@@ -174,12 +188,23 @@
         if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv:
             pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.")

-        if not is_fused_attn_kernel_available(
-                self.dtype, self.dtype, self.qkv_layout, self.attn_bias_type, self.attn_mask_type,
-                self.dropout_prob, self.num_heads_q, self.num_heads_kv, self.max_seqlen_q,
-                self.max_seqlen_kv, self.head_dim):
+        self.backend = FusedAttnHelper(
+            self.dtype, self.dtype, self.qkv_layout.value, self.attn_bias_type.value,
+            self.attn_mask_type.value, self.dropout_prob, self.num_heads_q, self.num_heads_kv,
+            self.max_seqlen_q, self.max_seqlen_kv, self.head_dim).get_fused_attn_backend()
+        if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
             pytest.skip("Unsupported inputs combination or device compute capability.")

+        if self.bias_shape != BiasShape.BIAS_1HSS:
+            if self.attn_bias_type != AttnBiasType.POST_SCALE_BIAS:
+                pytest.skip("B1SS, BHSS and 11SS bias shapes require POST_SCALE_BIAS.")
+            elif self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
+                pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for "
+                            "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK.")
+            elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
+                pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for "
+                            "the F16_arbitrary_seqlen backend.")

     def _setup_inputs(self):
         self._check_configs()
         key = jax.random.PRNGKey(0)
@@ -187,14 +212,38 @@
         q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
         k_shape = v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim)
-        bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
+        if self.bias_shape == BiasShape.BIAS_1HSS:
+            bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
+        elif self.bias_shape == BiasShape.BIAS_B1SS:
+            bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
+        elif self.bias_shape == BiasShape.BIAS_BHSS:
+            bias_shape = (self.batch_size, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
+        elif self.bias_shape == BiasShape.BIAS_11SS:
+            bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
+        else:
+            pytest.xfail("PyTest attempted to use an unrecognized bias layout!")

-        self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1)
-        self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1)
-        self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1)
+        self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.)
+        self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.)
+        self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.)

-        with_bias = self.attn_bias_type != AttnBiasType.NO_BIAS
-        self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1) if with_bias else None
+        if self.attn_bias_type != AttnBiasType.NO_BIAS:
+            if self.bias_shape == BiasShape.BIAS_1HSS:
+                self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.)
+            else:
+                # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
+                # an arbitrary mask where (True/False -> 0/-Inf)
+                cudnn_neg_inf = -2.**27. if self.dtype == jnp.bfloat16 else -2.**15.
+                self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
+                max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
+                seq_id_size = max_id * 5 // 128    # 5 ids per interval of 128 sequences
+                seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
+                for i in range(1, len(seq_id)):
+                    self.bias = \
+                        self.bias.at[:, :, seq_id[i-1]:seq_id[i], seq_id[i-1]:seq_id[i]].set(0.)
+        else:
+            self.bias = None

         if self.attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
             pad_ratio = 0.0
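The `cudnn_neg_inf` fill values above stand in for -inf so that the additive bias can emulate an arbitrary boolean mask without overflowing in half precision: -2^15 = -32768 is comfortably representable in float16 (max finite value ~65504), while bfloat16's wider exponent range allows the larger -2^27. A self-contained sketch of the same idea (the helper name here is illustrative, not part of the test):

```python
# Illustrative only: turn a boolean "allowed" pattern into an additive bias the
# way the test above does, using a large-but-finite negative instead of -inf.
import jax.numpy as jnp

def mask_to_bias(allowed, dtype):
    """allowed: boolean [b, 1, s_q, s_kv]; True -> 0, False -> large negative."""
    neg_inf = -2.0**27 if dtype == jnp.bfloat16 else -2.0**15   # -32768 fits in fp16
    return jnp.where(allowed, 0.0, neg_inf).astype(dtype)

allowed = jnp.tril(jnp.ones((8, 8), dtype=bool))[None, None]    # e.g. a causal pattern
bias = mask_to_bias(allowed, jnp.float16)                       # add to the scaled logits
```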
@@ -233,22 +282,19 @@
         primitive_out = customcall_fused_dpa(*args, **kwargs).astype(jnp.float32)
         reference_out = jax_dpa(*args, **kwargs).astype(jnp.float32)

-        primitive_valid, primitive_invalid = jnp.split(primitive_out, (self.valid_len_q,), axis=1)
-        reference_valid, _ = jnp.split(reference_out, (self.valid_len_q,), axis=1)
-
+        # Skip elementwise comparison when dropout enabled
         if self.is_training and self.dropout_prob > 0.:
             return

-        np.testing.assert_allclose(primitive_valid, reference_valid, atol=1e-2, rtol=1e-4)
-        np.testing.assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid))
+        primitive_valid, primitive_invalid = jnp.split(primitive_out, (self.valid_len_q,), axis=1)
+        reference_valid, _ = jnp.split(reference_out, (self.valid_len_q,), axis=1)
+        assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
+        assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

     def test_backward(self):
         """
         Test value_and_grad with JIT, which includes both forward and backward
         """
-        if not self.is_training:
-            pytest.skip("Backward doesn't support inference")

         self._setup_inputs()
@@ -271,62 +317,71 @@
             'qkv_layout': self.qkv_layout,
         }

+        # We can compute dBias only for the [1, h, s, s] layout
+        arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape.BIAS_1HSS else (0, 1, 2)
+
         # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
         jitted_primitive = jit(
             value_and_grad(
                 lambda q, k, v, bias, *args: grad_func(customcall_fused_dpa, q, k, v, bias, *args,
-                                                       **kwargs), (0, 1, 2, 3)))
+                                                       **kwargs), arg_nums))
         jitted_reference = jit(
             value_and_grad(
-                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
-                (0, 1, 2, 3)))
+                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args,
+                                                       **kwargs), arg_nums))

         primitive_out, primitive_dgrad = jitted_primitive(*args)
         reference_out, reference_dgrad = jitted_reference(*args)

         # Skip elementwise comparison when dropout enabled
-        if self.dropout_prob > 0.:
+        if self.dropout_prob > 0.0:
             return

-        np.testing.assert_allclose(primitive_out.astype(jnp.float32),
-                                   reference_out.astype(jnp.float32),
-                                   atol=1e-5,
-                                   rtol=1e-3)
-
-        # Convert the outputs to float32 for the elementwise comparison
-        primitive_dq, primitive_dk, primitive_dv, primitive_dbias = map(
-            jnp.float32, primitive_dgrad)
-        reference_dq, reference_dk, reference_dv, reference_dbias = map(
-            jnp.float32, reference_dgrad)
+        assert_allclose(primitive_out.astype(jnp.float32),
+                        reference_out.astype(jnp.float32),
+                        dtype=self.dtype)

         def check_dqkv(primitive, reference, valid_len):
             primitive_valid, primitive_invalid = jnp.split(primitive, (valid_len,), axis=1)
             reference_valid, reference_invalid = jnp.split(reference, (valid_len,), axis=1)

-            np.testing.assert_allclose(primitive_valid, reference_valid, atol=1e-4, rtol=1e-3)
-            assert jnp.allclose(primitive_invalid, reference_invalid)
-            assert jnp.allclose(primitive_invalid, jnp.zeros_like(primitive_invalid))
+            assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
+            assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
+            assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)
+
+        # Convert the outputs to float32 for the elementwise comparison
+        primitive_dq, primitive_dk, primitive_dv = map(jnp.float32, primitive_dgrad[:3])
+        reference_dq, reference_dk, reference_dv = map(jnp.float32, reference_dgrad[:3])

         check_dqkv(primitive_dq, reference_dq, self.valid_len_q)
         check_dqkv(primitive_dk, reference_dk, self.valid_len_kv)
         check_dqkv(primitive_dv, reference_dv, self.valid_len_kv)

-        if self.attn_bias_type != AttnBiasType.NO_BIAS:
-            # dbias valid part
-            np.testing.assert_allclose(primitive_dbias[..., :self.valid_len_q, :self.valid_len_kv],
-                                       reference_dbias[..., :self.valid_len_q, :self.valid_len_kv],
-                                       atol=3e-5,
-                                       rtol=1e-4)
-            # dbias padded part
-            np.testing.assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
-                                       reference_dbias[..., self.valid_len_q:, self.valid_len_kv:])
-            assert jnp.allclose(
-                primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
-                jnp.zeros_like(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:]))
+        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape.BIAS_1HSS:
+            primitive_dbias = jnp.float32(primitive_dgrad[3])
+            reference_dbias = jnp.float32(reference_dgrad[3])
+
+            assert_allclose(
+                primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
+                jnp.zeros_like(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:]),
+                dtype=self.dtype)
+
+            # dbias padded part
+            assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
+                            reference_dbias[..., self.valid_len_q:, self.valid_len_kv:],
+                            dtype=self.dtype)
+
+            # dbias valid part
+            assert_allclose(primitive_dbias[..., :self.valid_len_q, :self.valid_len_kv],
+                            reference_dbias[..., :self.valid_len_q, :self.valid_len_kv],
+                            dtype=self.dtype)

+@pytest.mark.parametrize('bias_shape', [
+    pytest.param(BiasShape.BIAS_1HSS, id='1-H-S-S'),
+    pytest.param(BiasShape.BIAS_B1SS, id='B-1-S-S'),
+    pytest.param(BiasShape.BIAS_BHSS, id='B-H-S-S'),
+    pytest.param(BiasShape.BIAS_11SS, id='1-1-S-S'),
+])
 @pytest.mark.parametrize('attn_bias_type', [
     pytest.param(AttnBiasType.NO_BIAS, id='NO_BIAS'),
     pytest.param(AttnBiasType.POST_SCALE_BIAS, id='POST_SCALE_BIAS'),
@@ -338,42 +393,52 @@
     pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id='PADDING_CAUSAL'),
 ])
 @pytest.mark.parametrize('qkv_layout', [
-    pytest.param(QKVLayout.BS3HD, id='qkvpacked'),
-    pytest.param(QKVLayout.BSHD_BS2HD, id='kvpacked'),
-    pytest.param(QKVLayout.BSHD_BSHD_BSHD, id='separate'),
+    pytest.param(QKVLayout.BS3HD, id='QKV_PACKED'),
+    pytest.param(QKVLayout.BSHD_BS2HD, id='KV_PACKED'),
+    pytest.param(QKVLayout.BSHD_BSHD_BSHD, id='SEPARATE'),
+])
+@pytest.mark.parametrize('dtype', [
+    pytest.param(jnp.bfloat16, id="BF16"),
+    pytest.param(jnp.float16, id="FP16")
+])
+@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d',[
+    pytest.param(32,  128,  128, 16, 16, 64, id='32-128-128-16-16-64-SELF'),
+    pytest.param( 4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'),
+    pytest.param(32,  512,  128, 16, 16, 64, id='32-512-128-16-16-64-CROSS'),
+    pytest.param( 4, 2048, 1024, 12, 12, 64, id='4-2048-1048-12-12-64-CROSS'),
+    pytest.param(32,  128,  128, 16,  8, 64, id='32-128-128-16-8-64-GQA'),
+    pytest.param( 4, 2048, 2048, 12,  6, 64, id='4-2048-2048-12-6-64-GQA')
+])
+@pytest.mark.parametrize('dropout_prob', [
+    pytest.param(0.0, id="DROP_0.0"),
+    pytest.param(0.1, id="DROP_0.1")
+])
+@pytest.mark.parametrize('is_training', [
+    pytest.param(True, id='TRAINING'),
+    pytest.param(False, id='INFERENCE'),
 ])
-@pytest.mark.parametrize('dropout_prob', [0., 0.1])
-@pytest.mark.parametrize('is_training',
-                         [pytest.param(True, id='training'),
-                          pytest.param(False, id='inference')])
-@pytest.mark.parametrize(
-    'dtype', [pytest.param(jnp.bfloat16, id="BF16"),
-              pytest.param(jnp.float16, id="FP16")])
-@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d',
-                         [(32, 128, 128, 16, 16, 64), (4, 2048, 2048, 12, 12, 64),
-                          pytest.param(32, 512, 128, 16, 16, 64, id='32-512-128-16-16-64-cross'),
-                          pytest.param(4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA')])
 class TestFusedAttn:
     """
     Fused attention tester
     """
     @staticmethod
-    def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
-                     dtype, is_training, qkv_layout):
+    def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
+                     dropout_prob, dtype, is_training, qkv_layout, bias_shape):
         """
         Test forward with parameterized configs
         """
         runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
-                                 dropout_prob, dtype, is_training, qkv_layout)
+                                 dropout_prob, dtype, is_training, qkv_layout, bias_shape)
         runner.test_forward()

     @staticmethod
-    def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
-                      dtype, is_training, qkv_layout):
+    def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
+                      dropout_prob, dtype, is_training, qkv_layout, bias_shape):
         """
         Test backward with parameterized configs
         """
+        if not is_training:
+            pytest.skip("Backward pass does not support inference.")
+
         runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
-                                 dropout_prob, dtype, is_training, qkv_layout)
+                                 dropout_prob, dtype, True, qkv_layout, bias_shape)
         runner.test_backward()
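For reference, the runner above can also be driven directly, outside pytest. A hedged usage sketch (all names come from the JAX test file in this diff; the configuration is just one of the parametrized cases and assumes the test module's imports are in scope):

```python
# Illustrative only: run the forward comparison for a single configuration of
# the FusedAttnRunner dataclass defined above.
runner = FusedAttnRunner(
    4, 2048, 2048, 12, 12, 64,                     # b, s_q, s_kv, h_q, h_kv, d
    AttnBiasType.POST_SCALE_BIAS, AttnMaskType.CAUSAL_MASK,
    0.0,                                           # dropout_prob
    jnp.bfloat16, True,                            # dtype, is_training
    QKVLayout.BS3HD, BiasShape.BIAS_B1SS)
runner.test_forward()
```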
@@ -85,6 +85,7 @@
         attn_bias_type: str,
         alibi_type: str = "none",
         num_layers: int = 1,
+        bias_shape: str = "1hss",
     ):
         self.batch_size = batch_size
         self.num_heads = num_heads
@@ -100,6 +101,7 @@
         self.alibi_type = alibi_type
         self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
         self.num_layers = num_layers
+        self.bias_shape = bias_shape

 def _is_fused_attention_supported(
     config: ModelConfig,
@@ -379,6 +381,31 @@ def test_dpa_bias(dtype, model_configs, model):
     """Test DotProductAttention module with different bias types"""
     test_dot_product_attention(dtype, model_configs, model, False, True, None, False)

+model_configs_bias_shapes = {
+    #     test:               b,  h, hg,   d,   sq,  skv,   p,
+    "bias_1_0": ModelConfig(  4, 16, 16,  64,  128,  128, 0.0,
+    #                         mask,       bias,             bias_shape,
+                            "no_mask", "post_scale_bias", bias_shape='11ss'),
+    "bias_1_1": ModelConfig(  2, 16, 16,  64,  128,  128, 0.0,
+                            "no_mask", "post_scale_bias", bias_shape='1hss'),
+    "bias_1_2": ModelConfig(  4, 24, 24, 128, 2048, 2048, 0.0,
+                            "no_mask", "post_scale_bias", bias_shape='b1ss'),
+    "bias_1_3": ModelConfig(  2, 24, 24, 128, 2048, 2048, 0.0,
+                            "no_mask", "post_scale_bias", bias_shape='bhss'),
+    "bias_1_4": ModelConfig(  4, 24, 24, 128, 2048, 2048, 0.0,
+                            "causal", "alibi", bias_shape='1hss', alibi_type='custom'),
+    "bias_1_5": ModelConfig(  2, 24, 24, 128, 2048, 2048, 0.0,
+                            "causal", "alibi", bias_shape='bhss', alibi_type='custom'),
+}
+
+@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.parametrize("dtype", param_types_lean)
+@pytest.mark.parametrize("model_configs", [model_configs_bias_shapes])
+@pytest.mark.parametrize("model", model_configs_bias_shapes.keys())
+def test_dpa_bias_shapes(dtype, model_configs, model):
+    """Test DotProductAttention module with different bias types and shapes"""
+    test_dot_product_attention(dtype, model_configs, model, False, True, None, False)
+
 model_configs_swa = {
     #    test:                b,  h, hg,  d,  sq, skv,   p,      mask,      bias
     "swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
@@ -510,10 +537,13 @@ def _run_dot_product_attention(
         window_size, attention_mask = None, None

     alibi_slopes = None
-    if config.attn_bias_type == "alibi":
-        if config.alibi_type == "custom":
+    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
+        if config.bias_shape == "1hss":
             alibi_slopes = torch.randn(
                 config.num_heads).abs().to(dtype=torch.float32, device="cuda")
+        if config.bias_shape == "bhss":
+            alibi_slopes = torch.randn(
+                config.batch_size, config.num_heads).abs().to(dtype=torch.float32, device="cuda")

     # Create input tensors
     dim_to_num = {
@@ -527,6 +557,7 @@
         'tg' : cu_seqlens_kv[-1],
         '3'  : 3,
         '2'  : 2,
+        '1'  : 1,
         }
     inp = []
     for i,layout in enumerate(qkv_layout.split('_')):
@@ -566,8 +597,12 @@
     if config.attn_bias_type in ['no_bias', 'alibi']:
         bias = None
     if config.attn_bias_type == 'post_scale_bias':
-        bias = torch.randn(1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv,
-                           dtype=dtype, device="cuda")
+        shape = '_'.join(config.bias_shape)
+        shape = shape.replace('_s_s', '_sq_skv')
+        tensor_shape = [dim_to_num[j] for j in shape.split('_')]
+        bias = torch.randn(tensor_shape, dtype=dtype, device="cuda")
+        if config.bias_shape != '1hss':
+            bias.requires_grad = False

     # Create RNG
     _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
......
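In the PyTorch test above, the bias tensor shape is derived from the `bias_shape` string ('1hss', 'b1ss', 'bhss', '11ss') by mapping each dimension letter through `dim_to_num`. A self-contained sketch of that mapping (the numeric values below are placeholders; the real test fills them from the config and sequence lengths):

```python
# Illustrative only: reproduce the string-to-shape mapping used above.
dim_to_num = {'b': 4, 'h': 16, 'sq': 128, 'skv': 128, '1': 1}

def bias_tensor_shape(bias_shape: str) -> list:
    shape = '_'.join(bias_shape)              # 'b1ss' -> 'b_1_s_s'
    shape = shape.replace('_s_s', '_sq_skv')  # 'b_1_s_s' -> 'b_1_sq_skv'
    return [dim_to_num[j] for j in shape.split('_')]

print(bias_tensor_shape('b1ss'))   # [4, 1, 128, 128]
print(bias_tensor_shape('1hss'))   # [1, 16, 128, 128]
```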
@@ -72,6 +72,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
     FADescriptor_v1 descriptor{b,                   h,
                                hg,                  s_q,
                                s_kv,                d,
+                               bias_b,              bias_h,
                                scaling_factor,      is_training,
                                dropout_probability, layout,
                                bias_type,           mask_type,
@@ -316,6 +317,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
     FADescriptor_v1 descriptor{b,                   h,
                                hg,                  s_q,
                                s_kv,                d,
+                               bias_b,              bias_h,
                                scaling_factor,      true,
                                dropout_probability, layout,
                                bias_type,           mask_type,
@@ -426,8 +428,13 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
                                 .set_dim({bias_b, bias_h, s_q, s_kv})
                                 .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
             sdpa_backward_options.set_bias(bias);
+            // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s]
+            // are not supported for dbias calculation but they are
+            // supported for forward bias calculation
+            if ((bias_b == 1) && (bias_h == h)) {
                 sdpa_backward_options.set_dbias(dBias);
+            }
         }

         if (is_padding) {
             seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
@@ -541,7 +548,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
         if (is_bias) {
             variant_pack[bias] = devPtrBias;
+            if ((bias_b == 1) && (bias_h == h)) {
                 variant_pack[dBias] = devPtrdBias;
+            } else {
+                variant_pack[dBias] = nullptr;
+            }
         }

         if (is_padding) {
......
@@ -103,6 +103,8 @@ struct FADescriptor_v1 {
     std::int64_t s_q;
     std::int64_t s_kv;
     std::int64_t d;
+    std::int64_t bias_b;
+    std::int64_t bias_h;
     float attnScale;
     bool isTraining;
     float dropoutProbability;
@@ -112,11 +114,12 @@ struct FADescriptor_v1 {
     cudnn_frontend::DataType_t tensor_type;

     bool operator<(const FADescriptor_v1 &rhs) const {
-        return std::tie(b, h, hg, s_q, s_kv, d,
+        return std::tie(b, h, hg, s_q, s_kv, d, bias_b, bias_h,
                         attnScale, isTraining, dropoutProbability,
                         layout, mask_type, bias_type, tensor_type)
                < std::tie(
                      rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d,
+                     rhs.bias_b, rhs.bias_h,
                      rhs.attnScale, rhs.isTraining,
                      rhs.dropoutProbability, rhs.layout,
                      rhs.mask_type, rhs.bias_type,
......
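FADescriptor_v1 is the key used to cache cuDNN execution plans, so adding bias_b and bias_h to the struct and to operator< keeps plans built for different bias shapes from colliding in that cache (the "add bias_b/h to plan cache" commit above). A Python analogy of the same idea, purely illustrative (the real cache is a C++ map keyed by this struct):

```python
# Illustrative analogy: a plan cache keyed by the full problem descriptor.
# Without bias_b/bias_h in the key, a plan built for a [1, h, s, s] bias could
# be wrongly reused for, say, a [b, 1, s, s] bias with the same b/h/s/d.
plan_cache = {}

def get_plan(b, h, hg, s_q, s_kv, d, bias_b, bias_h, bias_type, mask_type):
    key = (b, h, hg, s_q, s_kv, d, bias_b, bias_h, bias_type, mask_type)
    if key not in plan_cache:
        plan_cache[key] = f"build plan for {key}"   # stands in for cuDNN graph building
    return plan_cache[key]
```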
This diff is collapsed.
This diff is collapsed.
@@ -105,11 +105,13 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t paddin
                                                 DType dtype, float scale_factor);

 struct CustomCallFusedAttnDescriptor {
-    size_t batch_size;
+    size_t input_batch;
+    size_t bias_batch;
     size_t q_max_seqlen;
     size_t kv_max_seqlen;
-    size_t num_heads;
+    size_t attn_heads;
     size_t num_gqa_groups;
+    size_t bias_heads;
     size_t head_dim;
     size_t wkspace_size;
     float scaling_factor;
@@ -122,10 +124,11 @@
 };

 pybind11::bytes PackCustomCallFusedAttnDescriptor(
-    size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
-    size_t num_gqa_groups, size_t head_dim, size_t wkspace_size, float scaling_factor,
-    float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
-    DType wkspace_dtype, bool is_training);
+    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
+    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
+    size_t wkspace_size, float scaling_factor, float dropout_probability,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    DType dtype, DType wkspace_dtype, bool is_training);

 NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                             NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
@@ -205,47 +208,53 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers,
                                             std::size_t opaque_len);

 pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
-    size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, float scaling_factor,
-    float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
-    bool is_training);
+    size_t input_batch, size_t bias_batch, size_t max_seqlen,
+    size_t attn_heads, size_t bias_heads, size_t head_dim,
+    float scaling_factor, float dropout_probability,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);

 void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
                           size_t opaque_len);

 pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
-    size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, float scaling_factor,
-    float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
-    bool is_training);
+    size_t input_batch, size_t bias_batch, size_t max_seqlen,
+    size_t attn_heads, size_t bias_heads, size_t head_dim,
+    float scaling_factor, float dropout_probability,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);

 void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
                            size_t opaque_len);

 pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
-    size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
-    size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
+    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
+    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
+    float scaling_factor, float dropout_probability,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);

 void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
                            size_t opaque_len);

 pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
-    size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
-    size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
+    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
+    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
+    float scaling_factor, float dropout_probability,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);

 void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
                             size_t opaque_len);

 pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
-    size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
-    size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
+    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
+    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
+    float scaling_factor, float dropout_probability,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);

 void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

 pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
-    size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
-    size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
+    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
+    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
+    float scaling_factor, float dropout_probability,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);

 void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
......
@@ -2784,10 +2784,15 @@ class DotProductAttention(torch.nn.Module):
                     _, fu_core_attention_bias = get_alibi(
                         query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, alibi_slopes=alibi_slopes,
                         bias_dtype=query_layer.dtype)
-            if (fu_core_attention_bias.shape[0] != 1
-                    or fu_core_attention_bias.shape[1] != query_layer.shape[-2]):
-                # remove this line when cuDNN adds bwd support for [b, 1, s, s] and [b, h, s, s]
+            if (use_fused_attention
+                    and fu_core_attention_bias_type == "post_scale_bias"
+                    and (fu_core_attention_bias.shape[0] != 1
+                    or fu_core_attention_bias.shape[1] != query_layer.shape[-2])):
+                if fu_core_attention_bias.requires_grad:
+                    # remove this line when cuDNN adds bwd support for
+                    # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
                     use_fused_attention = False
+                else:
                     # max512 backend will only support [1, h, s, s]
                     os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
@@ -2812,6 +2817,11 @@ class DotProductAttention(torch.nn.Module):
                 use_fused_attention and is_backend_avail and \
                 (not context_parallel or \
                  fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]))
+            if (fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
+                    and fu_core_attention_bias_type == "post_scale_bias"
+                    and (fu_core_attention_bias.shape[0] != 1
+                    or fu_core_attention_bias.shape[1] != query_layer.shape[-2])):
+                use_fused_attention = False

             # Filter: determinism.
             # backend | deterministic
......
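Taken together, the two filters above say: a non-[1, h, s, s] post-scale bias keeps fused attention only when it does not need dBias, and only on the F16_arbitrary_seqlen backend. A hedged paraphrase as a standalone predicate (the function and parameter names here are illustrative, not the module's API):

```python
# Illustrative only: the effective selection rule for non-[1, h, s, s]
# post-scale biases introduced by the change above.
def fused_attn_allowed(bias_shape, num_heads, bias_requires_grad, backend):
    is_1hss = (bias_shape[0] == 1 and bias_shape[1] == num_heads)
    if is_1hss:
        return True                             # all backends, dBias supported
    if bias_requires_grad:
        return False                            # no dBias for other shapes yet
    return backend == "F16_arbitrary_seqlen"    # max512 only handles [1, h, s, s]

print(fused_attn_allowed((2, 1, 128, 128), 16, False, "F16_arbitrary_seqlen"))  # True
print(fused_attn_allowed((2, 1, 128, 128), 16, True, "F16_arbitrary_seqlen"))   # False
```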