Unverified Commit b8eea8aa authored by cyanguwa, committed by GitHub

[C/PyTorch/Jax] Add support for more bias shapes (#677)



* added support for arbitrary bias shapes for fused_attn
Signed-off-by: Alp Dener <adener@nvidia.com>

* Fix linting
Signed-off-by: Alp Dener <adener@nvidia.com>

* Add b1ss/bhss/11ss bias shapes when not requiring dBias
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add bias_b/h to plan cache
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fixed compile errors after PR653 merge
Signed-off-by: Alp Dener <adener@nvidia.com>

* updated JAX unittests for new bias shapes
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed mismatched mask type checking
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected skip condition
Signed-off-by: Alp Dener <adener@nvidia.com>

* fix selection logic for A100s
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* corrected skip checks for bias shapes
Signed-off-by: Alp Dener <adener@nvidia.com>

* resolved test issues but neginf with float16 is still problematic with JAX
Signed-off-by: Alp Dener <adener@nvidia.com>

* new bias shapes passing TE JAX CI for seqlen <= 512, seq_q == seq_kv and h_q == h_kv conditions
Signed-off-by: Alp Dener <adener@nvidia.com>

* TE/JAX fused attn tests for new bias shapes passing with neg_inf=-2**27 for Bfloat16 and -2**15 for Float16
Signed-off-by: Alp Dener <adener@nvidia.com>

* code style fixes and test parameter ID cleanup
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed incorrect skip condition for backward fused attn test
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: Alp Dener <adener@nvidia.com>
parent 04040957
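
For context, a minimal sketch of the bias shapes this change enables (dimension values below are illustrative only; the tests in the diff parameterize them):

import jax.numpy as jnp

# Illustrative dimensions; b = batch, h = attention heads, s_q/s_kv = sequence lengths.
b, h, s_q, s_kv = 4, 16, 128, 128

post_scale_bias_shapes = {
    '1HSS': (1, h, s_q, s_kv),   # original shape; the only one that also supports dBias
    'B1SS': (b, 1, s_q, s_kv),   # new: per-batch bias, forward bias-add only
    'BHSS': (b, h, s_q, s_kv),   # new: per-batch, per-head bias, forward bias-add only
    '11SS': (1, 1, s_q, s_kv),   # new: broadcast over batch and heads, forward bias-add only
}
biases = {name: jnp.zeros(shape, dtype=jnp.bfloat16)
          for name, shape in post_scale_bias_shapes.items()}
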
......@@ -2,14 +2,15 @@
#
# See LICENSE for license information.
"""Tests for fused attention"""
import sys
from enum import Enum
from dataclasses import dataclass
from functools import partial
from math import sqrt
import jax
import jax.numpy as jnp
import numpy as np
import pytest
from flax.linen import combine_masks
......@@ -21,7 +22,11 @@ from jax.typing import ArrayLike, DTypeLike
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn, fused_attn
from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine_jax import NVTE_Fused_Attn_Backend
from utils import assert_allclose
@pytest.fixture(autouse=True, scope='function')
......@@ -38,7 +43,7 @@ def clear_live_arrays():
def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: ArrayLike,
bias: ArrayLike, mask: ArrayLike, deterministic: bool,
dropout_rate: float, dropout_rng: ArrayLike,
scale_factor: float, dropout_rate: float, dropout_rng: ArrayLike,
dtype: DTypeLike) -> Array:
"""
Similar to flax.linen.dot_product_attention but with GQA support
......@@ -46,21 +51,21 @@ def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: Array
query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
dtype = query.dtype
depth = query.shape[-1]
query = query / jnp.sqrt(depth).astype(dtype)
b, s_q, h_q, d = query.shape
_, _, h_kv, _ = key.shape
_, s_kv, h_kv, _ = key.shape
assert (h_q % h_kv == 0) and (h_q >= h_kv)
num_groups = h_q // h_kv
grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
# logits with shape (b, h_kv, num_groups, s_q, s_kv)
logits = jnp.einsum('...qhgd,...khd->...hgqk', grouped_query, key)
logits = scale_factor * jnp.einsum('...qhgd,...khd->...hgqk', grouped_query, key)
if bias is not None:
if bias.ndim != logits.ndim:
bias = bias.reshape((1, *logits.shape[1:]))
# reshape logits without groups
logits = logits.reshape((b, h_kv * num_groups, s_q, s_kv))
# apply post-scale bias
logits = logits + bias
# reshape logits back to original
logits = logits.reshape((b, h_kv, num_groups, s_q, s_kv))
if mask is not None:
if mask.ndim != logits.ndim:
......@@ -114,6 +119,7 @@ def jax_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
bias=bias,
mask=mask,
deterministic=not kwargs['is_training'],
scale_factor=kwargs['scaling_factor'],
dropout_rate=kwargs['dropout_probability'],
dropout_rng=dropout_rng,
dtype=jnp.float32)
......@@ -149,6 +155,13 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng
**kwargs).astype(query.dtype)
class BiasShape(Enum):
BIAS_1HSS = '1HSS'
BIAS_B1SS = 'B1SS'
BIAS_BHSS = 'BHSS'
BIAS_11SS = '11SS'
@dataclass
class FusedAttnRunner:
"""
......@@ -166,6 +179,7 @@ class FusedAttnRunner:
dtype: DTypeLike
is_training: bool
qkv_layout: QKVLayout
bias_shape: BiasShape
def _check_configs(self):
if self.qkv_layout == QKVLayout.BS3HD and self.num_heads_q != self.num_heads_kv:
......@@ -174,12 +188,23 @@ class FusedAttnRunner:
if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv:
pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.")
if not is_fused_attn_kernel_available(
self.dtype, self.dtype, self.qkv_layout, self.attn_bias_type, self.attn_mask_type,
self.dropout_prob, self.num_heads_q, self.num_heads_kv, self.max_seqlen_q,
self.max_seqlen_kv, self.head_dim):
self.backend = FusedAttnHelper(
self.dtype, self.dtype, self.qkv_layout.value, self.attn_bias_type.value,
self.attn_mask_type.value, self.dropout_prob, self.num_heads_q, self.num_heads_kv,
self.max_seqlen_q, self.max_seqlen_kv, self.head_dim).get_fused_attn_backend()
if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
pytest.skip("Unsupported inputs combination or device compute capability.")
if self.bias_shape != BiasShape.BIAS_1HSS:
if self.attn_bias_type != AttnBiasType.POST_SCALE_BIAS:
pytest.skip("B1SS, BHSS and 11SS bias shapes require POST_SCALE_BIAS.")
elif self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for "
"AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK.")
elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for "
"the F16_arbitrary_seqlen backend.")
def _setup_inputs(self):
self._check_configs()
key = jax.random.PRNGKey(0)
......@@ -187,14 +212,38 @@ class FusedAttnRunner:
q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
k_shape = v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim)
if self.bias_shape == BiasShape.BIAS_1HSS:
bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
elif self.bias_shape == BiasShape.BIAS_B1SS:
bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
elif self.bias_shape == BiasShape.BIAS_BHSS:
bias_shape = (self.batch_size, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
elif self.bias_shape == BiasShape.BIAS_11SS:
bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
else:
pytest.xfail("PyTest attempted to use an unrecognized bias layout!")
self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1)
self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1)
self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1)
self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.)
self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.)
self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.)
with_bias = self.attn_bias_type != AttnBiasType.NO_BIAS
self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1) if with_bias else None
if self.attn_bias_type != AttnBiasType.NO_BIAS:
if self.bias_shape == BiasShape.BIAS_1HSS:
self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.)
else:
# [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
# an arbitrary mask where (True/False -> 0/-Inf)
cudnn_neg_inf = -2.**27. if self.dtype == jnp.bfloat16 else -2.**15.
self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
seq_id_size = max_id * 5 // 128 # 5 ids per interval of 128 sequences
seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
for i in range(1, len(seq_id)):
self.bias = \
self.bias.at[:, :, seq_id[i-1]:seq_id[i], seq_id[i-1]:seq_id[i]].set(0.)
else:
self.bias = None
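
The non-1HSS branches above use the bias as a stand-in for an arbitrary boolean mask: the tensor is filled with a large negative constant (-2**27 for bfloat16, -2**15 for float16, matching the commit notes) and selected blocks are zeroed out. A sketch of that conversion as a standalone helper (illustrative, not part of the diff):

def mask_to_post_scale_bias(bool_mask, dtype):
    # True -> keep (0.0); False -> suppress with a large negative value the dtype can represent.
    neg_inf = -2.**27. if dtype == jnp.bfloat16 else -2.**15.
    return jnp.where(bool_mask, 0.0, neg_inf).astype(dtype)

# e.g. a (1, 1, s_q, s_kv) bias built this way broadcasts the same mask over batch and heads.
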
if self.attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
pad_ratio = 0.0
......@@ -233,22 +282,19 @@ class FusedAttnRunner:
primitive_out = customcall_fused_dpa(*args, **kwargs).astype(jnp.float32)
reference_out = jax_dpa(*args, **kwargs).astype(jnp.float32)
primitive_valid, primitive_invalid = jnp.split(primitive_out, (self.valid_len_q,), axis=1)
reference_valid, _ = jnp.split(reference_out, (self.valid_len_q,), axis=1)
# Skip elementwise comparison when dropout enabled
if self.is_training and self.dropout_prob > 0.:
return
np.testing.assert_allclose(primitive_valid, reference_valid, atol=1e-2, rtol=1e-4)
np.testing.assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid))
primitive_valid, primitive_invalid = jnp.split(primitive_out, (self.valid_len_q,), axis=1)
reference_valid, _ = jnp.split(reference_out, (self.valid_len_q,), axis=1)
assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)
def test_backward(self):
"""
Test value_and_grad with JIT, which includes both forward and backward
"""
if not self.is_training:
pytest.skip("Backward doesn't support inference")
self._setup_inputs()
......@@ -271,62 +317,71 @@ class FusedAttnRunner:
'qkv_layout': self.qkv_layout,
}
# We can compute dBias only for the [1, h, s, s] layout
arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape.BIAS_1HSS else (0, 1, 2)
# Summing the results in FP16/BF16 may overflow, so use FP32 for the summation
jitted_primitive = jit(
value_and_grad(
lambda q, k, v, bias, *args: grad_func(customcall_fused_dpa, q, k, v, bias, *args,
**kwargs), (0, 1, 2, 3)))
**kwargs), arg_nums))
jitted_reference = jit(
value_and_grad(
lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
(0, 1, 2, 3)))
lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args,
**kwargs), arg_nums))
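
Since dBias is only computed for the [1, h, s, s] layout, the bias argument is excluded from the differentiated positions for the other shapes. A minimal illustration of how argnums controls that (assumed toy example, unrelated to the test tensors):

f = lambda q, k, v, bias: jnp.sum(q * k * v + bias)
x = jnp.ones(4)
_, grads_with_dbias = jax.value_and_grad(f, (0, 1, 2, 3))(x, x, x, x)   # 4 gradients, includes dBias
_, grads_no_dbias = jax.value_and_grad(f, (0, 1, 2))(x, x, x, x)        # 3 gradients, bias skipped
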
primitive_out, primitive_dgrad = jitted_primitive(*args)
reference_out, reference_dgrad = jitted_reference(*args)
# Skip elementwise comparison when dropout enabled
if self.dropout_prob > 0.:
if self.dropout_prob > 0.0:
return
np.testing.assert_allclose(primitive_out.astype(jnp.float32),
assert_allclose(primitive_out.astype(jnp.float32),
reference_out.astype(jnp.float32),
atol=1e-5,
rtol=1e-3)
# Convert the outputs to float32 for the elementwise comparison
primitive_dq, primitive_dk, primitive_dv, primitive_dbias = map(
jnp.float32, primitive_dgrad)
reference_dq, reference_dk, reference_dv, reference_dbias = map(
jnp.float32, reference_dgrad)
dtype=self.dtype)
def check_dqkv(primitive, reference, valid_len):
primitive_valid, primitive_invalid = jnp.split(primitive, (valid_len,), axis=1)
reference_valid, reference_invalid = jnp.split(reference, (valid_len,), axis=1)
np.testing.assert_allclose(primitive_valid, reference_valid, atol=1e-4, rtol=1e-3)
assert jnp.allclose(primitive_invalid, reference_invalid)
assert jnp.allclose(primitive_invalid, jnp.zeros_like(primitive_invalid))
assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)
# Convert the outputs to float32 for the elementwise comparison
primitive_dq, primitive_dk, primitive_dv = map(jnp.float32, primitive_dgrad[:3])
reference_dq, reference_dk, reference_dv = map(jnp.float32, reference_dgrad[:3])
check_dqkv(primitive_dq, reference_dq, self.valid_len_q)
check_dqkv(primitive_dk, reference_dk, self.valid_len_kv)
check_dqkv(primitive_dv, reference_dv, self.valid_len_kv)
if self.attn_bias_type != AttnBiasType.NO_BIAS:
# dbias valid part
np.testing.assert_allclose(primitive_dbias[..., :self.valid_len_q, :self.valid_len_kv],
reference_dbias[..., :self.valid_len_q, :self.valid_len_kv],
atol=3e-5,
rtol=1e-4)
# dbias padded part
np.testing.assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
reference_dbias[..., self.valid_len_q:, self.valid_len_kv:])
if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape.BIAS_1HSS:
primitive_dbias = jnp.float32(primitive_dgrad[3])
reference_dbias = jnp.float32(reference_dgrad[3])
assert jnp.allclose(
assert_allclose(
primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
jnp.zeros_like(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:]))
jnp.zeros_like(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:]),
dtype=self.dtype)
# dbias padded part
assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
reference_dbias[..., self.valid_len_q:, self.valid_len_kv:],
dtype=self.dtype)
# dbias valid part
assert_allclose(primitive_dbias[..., :self.valid_len_q, :self.valid_len_kv],
reference_dbias[..., :self.valid_len_q, :self.valid_len_kv],
dtype=self.dtype)
@pytest.mark.parametrize('bias_shape', [
pytest.param(BiasShape.BIAS_1HSS, id='1-H-S-S'),
pytest.param(BiasShape.BIAS_B1SS, id='B-1-S-S'),
pytest.param(BiasShape.BIAS_BHSS, id='B-H-S-S'),
pytest.param(BiasShape.BIAS_11SS, id='1-1-S-S'),
])
@pytest.mark.parametrize('attn_bias_type', [
pytest.param(AttnBiasType.NO_BIAS, id='NO_BIAS'),
pytest.param(AttnBiasType.POST_SCALE_BIAS, id='POST_SCALE_BIAS'),
......@@ -338,42 +393,52 @@ class FusedAttnRunner:
pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id='PADDING_CAUSAL'),
])
@pytest.mark.parametrize('qkv_layout', [
pytest.param(QKVLayout.BS3HD, id='qkvpacked'),
pytest.param(QKVLayout.BSHD_BS2HD, id='kvpacked'),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id='separate'),
pytest.param(QKVLayout.BS3HD, id='QKV_PACKED'),
pytest.param(QKVLayout.BSHD_BS2HD, id='KV_PACKED'),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id='SEPARATE'),
])
@pytest.mark.parametrize('dtype', [
pytest.param(jnp.bfloat16, id="BF16"),
pytest.param(jnp.float16, id="FP16")
])
@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d',[
pytest.param(32, 128, 128, 16, 16, 64, id='32-128-128-16-16-64-SELF'),
pytest.param( 4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'),
pytest.param(32, 512, 128, 16, 16, 64, id='32-512-128-16-16-64-CROSS'),
pytest.param( 4, 2048, 1024, 12, 12, 64, id='4-2048-1024-12-12-64-CROSS'),
pytest.param(32, 128, 128, 16, 8, 64, id='32-128-128-16-8-64-GQA'),
pytest.param( 4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA')
])
@pytest.mark.parametrize('dropout_prob', [
pytest.param(0.0, id="DROP_0.0"),
pytest.param(0.1, id="DROP_0.1")
])
@pytest.mark.parametrize('is_training', [
pytest.param(True, id='TRAINING'),
pytest.param(False, id='INFERENCE'),
])
@pytest.mark.parametrize('dropout_prob', [0., 0.1])
@pytest.mark.parametrize('is_training',
[pytest.param(True, id='training'),
pytest.param(False, id='inference')])
@pytest.mark.parametrize(
'dtype', [pytest.param(jnp.bfloat16, id="BF16"),
pytest.param(jnp.float16, id="FP16")])
@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d',
[(32, 128, 128, 16, 16, 64), (4, 2048, 2048, 12, 12, 64),
pytest.param(32, 512, 128, 16, 16, 64, id='32-512-128-16-16-64-cross'),
pytest.param(4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA')])
class TestFusedAttn:
"""
Fused attention tester
"""
@staticmethod
def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
dtype, is_training, qkv_layout):
def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
dropout_prob, dtype, is_training, qkv_layout, bias_shape):
"""
Test forward with parameterized configs
"""
runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
dropout_prob, dtype, is_training, qkv_layout)
dropout_prob, dtype, is_training, qkv_layout, bias_shape)
runner.test_forward()
@staticmethod
def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
dtype, is_training, qkv_layout):
def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
dropout_prob, dtype, is_training, qkv_layout, bias_shape):
"""
Test backward with parameterized configs
"""
if not is_training:
pytest.skip("Backward pass does not support inference.")
runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
dropout_prob, dtype, is_training, qkv_layout)
dropout_prob, dtype, True, qkv_layout, bias_shape)
runner.test_backward()
......@@ -85,6 +85,7 @@ class ModelConfig:
attn_bias_type: str,
alibi_type: str = "none",
num_layers: int = 1,
bias_shape: str = "1hss",
):
self.batch_size = batch_size
self.num_heads = num_heads
......@@ -100,6 +101,7 @@ class ModelConfig:
self.alibi_type = alibi_type
self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
self.num_layers = num_layers
self.bias_shape = bias_shape
def _is_fused_attention_supported(
config: ModelConfig,
......@@ -379,6 +381,31 @@ def test_dpa_bias(dtype, model_configs, model):
"""Test DotProductAttention module with different bias types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)
model_configs_bias_shapes = {
# test: b, h, hg, d, sq, skv, p,
"bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0,
# mask, bias, bias_shape,
"no_mask", "post_scale_bias", bias_shape='11ss'),
"bias_1_1": ModelConfig(2, 16, 16, 64, 128, 128, 0.0,
"no_mask", "post_scale_bias", bias_shape='1hss'),
"bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,
"no_mask", "post_scale_bias", bias_shape='b1ss'),
"bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0,
"no_mask", "post_scale_bias", bias_shape='bhss'),
"bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,
"causal", "alibi", bias_shape='1hss', alibi_type='custom'),
"bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0,
"causal", "alibi", bias_shape='bhss', alibi_type='custom'),
}
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias_shapes])
@pytest.mark.parametrize("model", model_configs_bias_shapes.keys())
def test_dpa_bias_shapes(dtype, model_configs, model):
"""Test DotProductAttention module with different bias types and shapes"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)
model_configs_swa = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
......@@ -510,10 +537,13 @@ def _run_dot_product_attention(
window_size, attention_mask = None, None
alibi_slopes = None
if config.attn_bias_type == "alibi":
if config.alibi_type == "custom":
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
if config.bias_shape == "1hss":
alibi_slopes = torch.randn(
config.num_heads).abs().to(dtype=torch.float32, device="cuda")
if config.bias_shape == "bhss":
alibi_slopes = torch.randn(
config.batch_size, config.num_heads).abs().to(dtype=torch.float32, device="cuda")
# Create input tensors
dim_to_num = {
......@@ -527,6 +557,7 @@ def _run_dot_product_attention(
'tg' : cu_seqlens_kv[-1],
'3' : 3,
'2' : 2,
'1' : 1,
}
inp = []
for i,layout in enumerate(qkv_layout.split('_')):
......@@ -566,8 +597,12 @@ def _run_dot_product_attention(
if config.attn_bias_type in ['no_bias', 'alibi']:
bias = None
if config.attn_bias_type == 'post_scale_bias':
bias = torch.randn(1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv,
dtype=dtype, device="cuda")
shape = '_'.join(config.bias_shape)
shape = shape.replace('_s_s', '_sq_skv')
tensor_shape = [dim_to_num[j] for j in shape.split('_')]
bias = torch.randn(tensor_shape, dtype=dtype, device="cuda")
if config.bias_shape != '1hss':
bias.requires_grad = False
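
The bias_shape string ('1hss', 'b1ss', 'bhss' or '11ss') is expanded into a concrete tensor shape through dim_to_num. A worked sketch of that expansion with illustrative dimension values:

dim_to_num = {'b': 4, 'h': 16, 'sq': 128, 'skv': 128, '1': 1}   # illustrative values
shape = '_'.join('b1ss')                   # 'b_1_s_s'
shape = shape.replace('_s_s', '_sq_skv')   # 'b_1_sq_skv'
tensor_shape = [dim_to_num[j] for j in shape.split('_')]   # [4, 1, 128, 128]
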
# Create RNG
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
......
......@@ -72,6 +72,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
FADescriptor_v1 descriptor{b, h,
hg, s_q,
s_kv, d,
bias_b, bias_h,
scaling_factor, is_training,
dropout_probability, layout,
bias_type, mask_type,
......@@ -316,6 +317,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
FADescriptor_v1 descriptor{b, h,
hg, s_q,
s_kv, d,
bias_b, bias_h,
scaling_factor, true,
dropout_probability, layout,
bias_type, mask_type,
......@@ -426,8 +428,13 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
.set_dim({bias_b, bias_h, s_q, s_kv})
.set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
sdpa_backward_options.set_bias(bias);
// shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s]
// are not supported for dbias calculation but they are
// supported for forward bias calculation
if ((bias_b == 1) && (bias_h == h)) {
sdpa_backward_options.set_dbias(dBias);
}
}
if (is_padding) {
seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
......@@ -541,7 +548,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
if (is_bias) {
variant_pack[bias] = devPtrBias;
if ((bias_b == 1) && (bias_h == h)) {
variant_pack[dBias] = devPtrdBias;
} else {
variant_pack[dBias] = nullptr;
}
}
if (is_padding) {
......
......@@ -103,6 +103,8 @@ struct FADescriptor_v1 {
std::int64_t s_q;
std::int64_t s_kv;
std::int64_t d;
std::int64_t bias_b;
std::int64_t bias_h;
float attnScale;
bool isTraining;
float dropoutProbability;
......@@ -112,11 +114,12 @@ struct FADescriptor_v1 {
cudnn_frontend::DataType_t tensor_type;
bool operator<(const FADescriptor_v1 &rhs) const {
return std::tie(b, h, hg, s_q, s_kv, d,
return std::tie(b, h, hg, s_q, s_kv, d, bias_b, bias_h,
attnScale, isTraining, dropoutProbability,
layout, mask_type, bias_type, tensor_type)
< std::tie(
rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d,
rhs.bias_b, rhs.bias_h,
rhs.attnScale, rhs.isTraining,
rhs.dropoutProbability, rhs.layout,
rhs.mask_type, rhs.bias_type,
......
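
FADescriptor_v1 doubles as the key for the cuDNN graph/plan cache, so the new bias dimensions must take part in the comparison; otherwise a call with a [b, h, s, s] bias could reuse a plan built for [1, h, s, s]. A rough Python analogy of that keying (an assumption-level sketch, not the C++ implementation):

plan_cache = {}

def get_plan(b, h, hg, s_q, s_kv, d, bias_b, bias_h, build_plan):
    key = (b, h, hg, s_q, s_kv, d, bias_b, bias_h)   # bias_b/bias_h are now part of the key
    if key not in plan_cache:
        plan_cache[key] = build_plan(*key)
    return plan_cache[key]
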
......@@ -1910,23 +1910,23 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
# outer_primitve is squeezed_mask, inner_primitive is cu_seqlen
del seqlen_or_cu_seqlen_aval
qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
*batch_shape, max_seqlen, nqkv, num_heads, head_dim = qkv_aval.shape
*input_batch_shape, max_seqlen, nqkv, attn_heads, head_dim = qkv_aval.shape
assert nqkv == 3
assert qkv_aval.dtype == bias_aval.dtype
output_shape = (*batch_shape, max_seqlen, num_heads, head_dim)
output_shape = (*input_batch_shape, max_seqlen, attn_heads, head_dim)
out_aval = qkv_aval.update(shape=output_shape, dtype=qkv_dtype)
# backend determines the softmax buffer shape/dtype
backend = FusedAttnHelper(qkv_dtype, qkv_dtype, NVTE_QKV_Layout.NVTE_BS3HD, attn_bias_type,
attn_mask_type, dropout_probability, num_heads, num_heads,
attn_mask_type, dropout_probability, attn_heads, attn_heads,
max_seqlen, max_seqlen, head_dim).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_shape = (*batch_shape, num_heads, max_seqlen, max_seqlen)
softmax_shape = (*input_batch_shape, attn_heads, max_seqlen, max_seqlen)
softmax_dtype = qkv_dtype
elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
softmax_shape = (*batch_shape, num_heads, max_seqlen, 1)
softmax_shape = (*input_batch_shape, attn_heads, max_seqlen, 1)
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
else:
raise ValueError(f'Unsupported {backend=}')
......@@ -1940,12 +1940,19 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
# do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
# prepare for the active fused-attn backend
batch_size = reduce(operator.mul, batch_shape)
input_batch = reduce(operator.mul, input_batch_shape)
wkspace_info = transformer_engine_jax.get_self_fused_attn_fwd_workspace_sizes(
batch_size, max_seqlen, num_heads, head_dim, scaling_factor, dropout_probability,
attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
input_batch, bias_batch, max_seqlen, attn_heads, bias_heads, head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
wkspace_aval = qkv_aval.update(shape=wkspace_info[0],
dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
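
The primitives now read the bias batch and head counts from the bias abstract value instead of assuming a [1, h, s, s] layout. A minimal sketch of that unpacking (shape values are illustrative):

from functools import reduce
import operator

bias_shape = (4, 16, 128, 128)                          # e.g. a BHSS bias
*bias_batch_shape, bias_heads, _, _ = bias_shape        # bias_batch_shape = [4], bias_heads = 16
bias_batch = reduce(operator.mul, bias_batch_shape)     # 4; 0 is used when there is no bias
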
......@@ -1974,14 +1981,21 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
qkv_aval = ctx.avals_in[0]
*batch_shape, max_seqlen, _, num_heads, head_dim = qkv_aval.shape
batch_size = reduce(operator.mul, batch_shape)
qkv_aval, bias_aval, *_ = ctx.avals_in
*input_batch_shape, max_seqlen, _, attn_heads, head_dim = qkv_aval.shape
input_batch = reduce(operator.mul, input_batch_shape)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch_size, max_seqlen, max_seqlen, num_heads, num_heads, head_dim, wkspace_aval.size,
input_batch, bias_batch, max_seqlen, max_seqlen,
attn_heads, attn_heads, bias_heads, head_dim, wkspace_aval.size,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(qkv_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
......@@ -2074,6 +2088,7 @@ def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, seqlen: jnp.ndarray
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=qkv.dtype)
return SelfFusedAttnFwdPrimitive.outer_primitive.bind(qkv,
bias,
seqlen,
......@@ -2105,15 +2120,21 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
del softmax_aux_aval, rng_state_aval, seqlen_or_cu_seqlen_aval
assert qkv_aval.dtype == bias_aval.dtype == output_aval.dtype == doutput_aval.dtype
*batch_shape, max_seqlen, nqkv, num_heads, head_dim = qkv_aval.shape
*input_batch_shape, max_seqlen, nqkv, attn_heads, head_dim = qkv_aval.shape
assert nqkv == 3
qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
batch_size = reduce(operator.mul, batch_shape)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
input_batch = reduce(operator.mul, input_batch_shape)
wkspace_shape, wkspace_dtype = \
transformer_engine_jax.get_self_fused_attn_bwd_workspace_sizes(
batch_size, max_seqlen, num_heads, head_dim,
input_batch, bias_batch, max_seqlen, attn_heads, bias_heads, head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(qkv_aval.dtype), is_training
)
......@@ -2147,14 +2168,21 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
qkv_aval = ctx.avals_in[0]
*batch_shape, max_seqlen, _, num_heads, head_dim = qkv_aval.shape
batch_size = reduce(operator.mul, batch_shape)
qkv_aval, bias_aval, *_ = ctx.avals_in
*input_batch_shape, max_seqlen, _, attn_heads, head_dim = qkv_aval.shape
input_batch = reduce(operator.mul, input_batch_shape)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch_size, max_seqlen, max_seqlen, num_heads, num_heads, head_dim, wkspace_aval.size,
input_batch, bias_batch, max_seqlen, max_seqlen,
attn_heads, attn_heads, bias_heads, head_dim, wkspace_aval.size,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(qkv_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
......@@ -2261,6 +2289,7 @@ def self_fused_attn_bwd(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: jnp.nd
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=qkv.dtype)
return SelfFusedAttnBwdPrimitive.outer_primitive.bind(qkv,
bias,
softmax_aux,
......@@ -2298,7 +2327,7 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
assert q_dtype == kv_dtype == bias_dtype
assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
*q_batch_shape, q_max_seqlen, num_heads, q_head_dim = q_aval.shape
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = kv_aval.shape
assert q_batch_shape == kv_batch_shape
assert q_head_dim == kv_head_dim
......@@ -2307,15 +2336,15 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
# backend determines the softmax buffer shape/dtype
backend = FusedAttnHelper(q_dtype, kv_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
attn_bias_type, attn_mask_type, dropout_probability, num_heads,
attn_bias_type, attn_mask_type, dropout_probability, attn_heads,
num_gqa_groups, q_max_seqlen, kv_max_seqlen,
q_head_dim).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_shape = (*q_batch_shape, num_heads, q_max_seqlen, kv_max_seqlen)
softmax_shape = (*q_batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen)
softmax_dtype = q_dtype
elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
softmax_shape = (*q_batch_shape, num_heads, q_max_seqlen, 1)
softmax_shape = (*q_batch_shape, attn_heads, q_max_seqlen, 1)
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
else:
raise ValueError(f'Unsupported {backend=}')
......@@ -2329,11 +2358,18 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
# do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
# prepare for the active fused-attn backend
batch_size = reduce(operator.mul, q_batch_shape)
input_batch = reduce(operator.mul, q_batch_shape)
wkspace_info = transformer_engine_jax.get_cross_fused_attn_fwd_workspace_sizes(
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, q_head_dim,
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, q_head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training)
wkspace_aval = q_aval.update(shape=wkspace_info[0],
......@@ -2364,15 +2400,22 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
q_aval, kv_aval, *_ = ctx.avals_in
*batch_shape, q_max_seqlen, num_heads, head_dim = q_aval.shape
q_aval, kv_aval, bias_aval, *_ = ctx.avals_in
*input_batch_shape, q_max_seqlen, attn_heads, head_dim = q_aval.shape
*_, kv_max_seqlen, _, num_gqa_groups, _ = kv_aval.shape
batch_size = reduce(operator.mul, batch_shape)
input_batch = reduce(operator.mul, input_batch_shape)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, head_dim,
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, head_dim,
wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
......@@ -2512,16 +2555,23 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
assert q_dtype == kv_dtype == bias_dtype == doutput_dtype
assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype
*q_batch_shape, q_max_seqlen, num_heads, q_head_dim = q_aval.shape
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = kv_aval.shape
assert q_batch_shape == kv_batch_shape
assert q_head_dim == kv_head_dim
assert nkv == 2
batch_size = reduce(operator.mul, q_batch_shape)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
input_batch = reduce(operator.mul, q_batch_shape)
wkspace_shape, wkspace_dtype = \
transformer_engine_jax.get_cross_fused_attn_bwd_workspace_sizes(
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, q_head_dim,
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, q_head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training
)
......@@ -2559,15 +2609,22 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
q_aval, kv_aval, *_ = ctx.avals_in
*batch_shape, q_max_seqlen, num_heads, head_dim = q_aval.shape
q_aval, kv_aval, bias_aval, *_ = ctx.avals_in
*input_batch_shape, q_max_seqlen, attn_heads, head_dim = q_aval.shape
*_, kv_max_seqlen, _, num_gqa_groups, _ = kv_aval.shape
batch_size = reduce(operator.mul, batch_shape)
input_batch = reduce(operator.mul, input_batch_shape)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, head_dim,
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, head_dim,
wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
......@@ -2684,6 +2741,7 @@ def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=q.dtype)
return CrossFusedAttnBwdPrimitive.outer_primitive.bind(q,
kv,
bias,
......@@ -2725,7 +2783,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
assert q_dtype == k_dtype == v_dtype == bias_dtype
assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
*q_batch_shape, q_max_seqlen, num_heads, q_head_dim = q_aval.shape
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape
assert q_batch_shape == kv_batch_shape
assert q_head_dim == kv_head_dim
......@@ -2734,14 +2792,14 @@ class FusedAttnFwdPrimitive(BasePrimitive):
# backend determines the softmax buffer shape/dtype
backend = FusedAttnHelper(q_dtype, k_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD, attn_bias_type,
attn_mask_type, dropout_probability, num_heads, num_gqa_groups,
attn_mask_type, dropout_probability, attn_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, q_head_dim).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_shape = (*q_batch_shape, num_heads, q_max_seqlen, kv_max_seqlen)
softmax_shape = (*q_batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen)
softmax_dtype = q_dtype
elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
softmax_shape = (*q_batch_shape, num_heads, q_max_seqlen, 1)
softmax_shape = (*q_batch_shape, attn_heads, q_max_seqlen, 1)
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
else:
raise ValueError(f'Unsupported {backend=}')
......@@ -2755,11 +2813,18 @@ class FusedAttnFwdPrimitive(BasePrimitive):
rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
# do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
# prepare for the active fused-attn backend
batch_size = reduce(operator.mul, q_batch_shape)
input_batch = reduce(operator.mul, q_batch_shape)
wkspace_info = transformer_engine_jax.get_fused_attn_fwd_workspace_sizes(
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, q_head_dim,
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, q_head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training)
wkspace_aval = q_aval.update(shape=wkspace_info[0],
......@@ -2790,16 +2855,23 @@ class FusedAttnFwdPrimitive(BasePrimitive):
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
q_aval, k_aval, v_aval, *_ = ctx.avals_in
*batch_shape, q_max_seqlen, num_heads, head_dim = q_aval.shape
q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
*batch_shape, q_max_seqlen, attn_heads, head_dim = q_aval.shape
*_, kv_max_seqlen, num_gqa_groups, _ = k_aval.shape
assert k_aval.shape == v_aval.shape
batch_size = reduce(operator.mul, batch_shape)
input_batch = reduce(operator.mul, batch_shape)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, head_dim,
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, head_dim,
wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
......@@ -2941,16 +3013,23 @@ class FusedAttnBwdPrimitive(BasePrimitive):
assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype
assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype
*q_batch_shape, q_max_seqlen, num_heads, q_head_dim = q_aval.shape
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape
assert q_batch_shape == kv_batch_shape
assert q_head_dim == kv_head_dim
assert k_aval.shape == v_aval.shape
batch_size = reduce(operator.mul, q_batch_shape)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
input_batch = reduce(operator.mul, q_batch_shape)
wkspace_shape, wkspace_dtype = \
transformer_engine_jax.get_fused_attn_bwd_workspace_sizes(
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, q_head_dim,
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, q_head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training
)
......@@ -2991,16 +3070,23 @@ class FusedAttnBwdPrimitive(BasePrimitive):
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
q_aval, k_aval, v_aval, *_ = ctx.avals_in
*batch_shape, q_max_seqlen, num_heads, head_dim = q_aval.shape
q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
*batch_shape, q_max_seqlen, attn_heads, head_dim = q_aval.shape
*_, kv_max_seqlen, num_gqa_groups, _ = k_aval.shape
assert k_aval.shape == v_aval.shape
batch_size = reduce(operator.mul, batch_shape)
input_batch = reduce(operator.mul, batch_shape)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, head_dim,
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, head_dim,
wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training)
......@@ -3122,6 +3208,7 @@ def fused_attn_bwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.nda
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=q.dtype)
return FusedAttnBwdPrimitive.outer_primitive.bind(q,
k,
v,
......
......@@ -94,14 +94,15 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t paddin
}
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
DType wkspace_dtype, bool is_training) {
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
size_t wkspace_size, float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
DType dtype, DType wkspace_dtype, bool is_training) {
return PackOpaque(CustomCallFusedAttnDescriptor{
batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, head_dim, wkspace_size,
scaling_factor, dropout_probability, bias_type, mask_type, dtype, wkspace_dtype,
is_training});
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups,
bias_heads, head_dim, wkspace_size, scaling_factor, dropout_probability,
bias_type, mask_type, dtype, wkspace_dtype, is_training});
}
void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream,
......@@ -962,8 +963,10 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack,
NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend,
void *softmax_buf, void *rng_state_buf = nullptr,
void *bias_buf = nullptr) {
auto batch_size = desc->batch_size;
auto num_heads = desc->num_heads;
auto input_batch = desc->input_batch;
auto bias_batch = desc->bias_batch;
auto attn_heads = desc->attn_heads;
auto bias_heads = desc->bias_heads;
auto q_max_seqlen = desc->q_max_seqlen;
auto kv_max_seqlen = desc->kv_max_seqlen;
......@@ -973,7 +976,7 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack,
Tensor *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
softmax_aux->data.dptr = softmax_buf;
softmax_aux->data.shape =
std::vector<size_t>{batch_size, num_heads, q_max_seqlen, kv_max_seqlen};
std::vector<size_t>{input_batch, attn_heads, q_max_seqlen, kv_max_seqlen};
softmax_aux->data.dtype = desc->dtype;
// arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax
......@@ -992,7 +995,8 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack,
tensor_pack->size = 3;
Tensor *bias_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[2]);
bias_aux->data.dptr = bias_buf;
bias_aux->data.shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
bias_aux->data.shape =
std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
bias_aux->data.dtype = desc->dtype;
}
}
......@@ -1026,26 +1030,27 @@ void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack,
}
pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
bool is_training) {
size_t input_batch, size_t bias_batch, size_t max_seqlen,
size_t attn_heads, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto qkv_shape = std::vector<size_t>{batch_size * max_seqlen, 3, num_heads, head_dim};
auto bias_shape = std::vector<size_t>{1, num_heads, max_seqlen, max_seqlen};
auto qkv_shape = std::vector<size_t>{input_batch * max_seqlen, 3, attn_heads, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, max_seqlen, max_seqlen};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto o_tensor = TensorWrapper(
nullptr, std::vector<size_t>{batch_size * max_seqlen, num_heads, head_dim}, dtype);
nullptr, std::vector<size_t>{input_batch, max_seqlen, attn_heads, head_dim}, dtype);
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, num_heads, num_heads, max_seqlen, max_seqlen, head_dim);
mask_type, dropout_probability, attn_heads, attn_heads, max_seqlen, max_seqlen, head_dim);
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
......@@ -1079,35 +1084,37 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu
void *workspace = buffers[7];
// tensor sizes
auto batch_size = descriptor.batch_size;
auto input_batch = descriptor.input_batch;
auto bias_batch = descriptor.bias_batch;
auto max_seqlen = descriptor.q_max_seqlen;
auto num_heads = descriptor.num_heads;
auto attn_heads = descriptor.attn_heads;
auto bias_heads = descriptor.bias_heads;
auto head_dim = descriptor.head_dim;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
auto dtype = descriptor.dtype;
auto qkv_shape = std::vector<size_t>{batch_size * max_seqlen, 3, num_heads, head_dim};
auto bias_shape = std::vector<size_t>{1, num_heads, max_seqlen, max_seqlen};
auto qkv_shape = std::vector<size_t>{input_batch * max_seqlen, 3, attn_heads, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, max_seqlen, max_seqlen};
// input tensors
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
auto cu_seqlens_tensor =
TensorWrapper(cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
TensorWrapper(cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
// output tensors
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in FP16/BF16
auto o_tensor = TensorWrapper(
output, std::vector<size_t>{batch_size * max_seqlen, num_heads, head_dim}, dtype);
output, std::vector<size_t>{input_batch * max_seqlen, attn_heads, head_dim}, dtype);
// prep RNG state
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, num_heads, num_heads, max_seqlen, max_seqlen, head_dim);
mask_type, dropout_probability, attn_heads, attn_heads, max_seqlen, max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, max_seqlen, max_seqlen, backend, stream);
// auxiliary tensors (to be propagated to the backward pass later)
......@@ -1131,14 +1138,15 @@ void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaqu
}
pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
bool is_training) {
size_t input_batch, size_t bias_batch, size_t max_seqlen,
size_t attn_heads, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto qkv_shape = std::vector<size_t>{batch_size * max_seqlen, 3, num_heads, head_dim};
auto output_shape = std::vector<size_t>{batch_size * max_seqlen, num_heads, head_dim};
auto bias_shape = std::vector<size_t>{1, num_heads, max_seqlen, max_seqlen};
auto qkv_shape = std::vector<size_t>{input_batch * max_seqlen, 3, attn_heads, head_dim};
auto output_shape = std::vector<size_t>{input_batch * max_seqlen, attn_heads, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, max_seqlen, max_seqlen};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
......@@ -1150,7 +1158,7 @@ pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
......@@ -1188,9 +1196,11 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq
void *workspace = buffers[9];
// tensor sizes
auto batch_size = descriptor.batch_size;
auto input_batch = descriptor.input_batch;
auto bias_batch = descriptor.bias_batch;
auto max_seqlen = descriptor.q_max_seqlen;
auto num_heads = descriptor.num_heads;
auto attn_heads = descriptor.attn_heads;
auto bias_heads = descriptor.bias_heads;
auto head_dim = descriptor.head_dim;
auto scaling_factor = descriptor.scaling_factor;
auto dropout_probability = descriptor.dropout_probability;
......@@ -1198,9 +1208,9 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq
auto mask_type = descriptor.mask_type;
auto dtype = descriptor.dtype;
auto qkv_shape = std::vector<size_t>{batch_size * max_seqlen, 3, num_heads, head_dim};
auto output_shape = std::vector<size_t>{batch_size * max_seqlen, num_heads, head_dim};
auto bias_shape = std::vector<size_t>{1, num_heads, max_seqlen, max_seqlen};
auto qkv_shape = std::vector<size_t>{input_batch * max_seqlen, 3, attn_heads, head_dim};
auto output_shape = std::vector<size_t>{input_batch * max_seqlen, attn_heads, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, max_seqlen, max_seqlen};
// input tensors
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
......@@ -1212,7 +1222,7 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq
auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
auto cu_seqlens_tensor =
TensorWrapper(cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
TensorWrapper(cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
// auxiliary tensors (propagated from the forward pass)
NVTETensorPack aux_input_tensors;
......@@ -1220,7 +1230,7 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, num_heads, num_heads, max_seqlen, max_seqlen, head_dim);
mask_type, dropout_probability, attn_heads, attn_heads, max_seqlen, max_seqlen, head_dim);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
rng_state, bias);
......@@ -1241,14 +1251,15 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq
}
pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto kv_shape = std::vector<size_t>{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
......@@ -1260,9 +1271,9 @@ pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
......@@ -1301,20 +1312,22 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
void *workspace = buffers[9];
// tensor sizes
auto batch_size = descriptor.batch_size;
auto input_batch = descriptor.input_batch;
auto bias_batch = descriptor.bias_batch;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto num_heads = descriptor.num_heads;
auto attn_heads = descriptor.attn_heads;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto bias_heads = descriptor.bias_heads;
auto head_dim = descriptor.head_dim;
auto scaling_factor = descriptor.scaling_factor;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto kv_shape = std::vector<size_t>{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
// input tensors
auto dtype = descriptor.dtype;
......@@ -1326,16 +1339,16 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in FP16/BF16
auto o_tensor = TensorWrapper(output, q_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
// prep RNG state
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, num_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
......@@ -1360,15 +1373,16 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
}
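As a point of reference, a minimal sketch of how the four bias layouts enabled by this change (11SS, 1HSS, B1SS, BHSS) all collapse onto the (bias_batch, bias_heads) pair threaded through the shapes above; the helper and its {b, h, sq, skv} argument are illustrative assumptions, not library API.

// Sketch only: derive (bias_batch, bias_heads) from a 4-D bias shape
// {b, h, sq, skv}, where b is 1 or input_batch and h is 1 or attn_heads.
#include <cstddef>
#include <utility>
#include <vector>

std::pair<std::size_t, std::size_t> GetBiasBatchHeads(
        const std::vector<std::size_t> &bias_dims) {    // bias_dims = {b, h, sq, skv}
    return {bias_dims.at(0), bias_dims.at(1)};          // 11SS -> {1, 1},  1HSS -> {1, h}
}                                                        // B1SS -> {b, 1},  BHSS -> {b, h}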
pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto kv_shape = std::vector<size_t>{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto output_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
......@@ -1382,9 +1396,9 @@ pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
......@@ -1426,21 +1440,23 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
void *workspace = buffers[12];
// tensor sizes
auto batch_size = descriptor.batch_size;
auto input_batch = descriptor.input_batch;
auto bias_batch = descriptor.bias_batch;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto num_heads = descriptor.num_heads;
auto attn_heads = descriptor.attn_heads;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto bias_heads = descriptor.bias_heads;
auto head_dim = descriptor.head_dim;
auto scaling_factor = descriptor.scaling_factor;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto kv_shape = std::vector<size_t>{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto output_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
// input tensors
auto dtype = descriptor.dtype;
......@@ -1455,9 +1471,9 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype);
auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
// auxiliary tensors (propagated from the forward pass)
NVTETensorPack aux_input_tensors;
......@@ -1465,7 +1481,7 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, num_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
rng_state, bias);
......@@ -1488,15 +1504,16 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
}
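A hedged sketch of the constraint noted in the commit message and enforced on the PyTorch side at the end of this diff: until cuDNN adds a backward path for the [1, 1, s, s], [b, 1, s, s] and [b, h, s, s] layouts, a bias that needs a gradient must keep the classic [1, h, s, s] form. The helper below is illustrative only.

#include <cstddef>

// Sketch: return true if dBias can be produced for this bias layout.
// Only the 1HSS layout ({1, attn_heads, sq, skv}) currently has a bwd path.
bool BiasLayoutSupportsDBias(std::size_t bias_batch, std::size_t bias_heads,
                             std::size_t attn_heads) {
    return bias_batch == 1 && bias_heads == attn_heads;
}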
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto k_shape = std::vector<size_t>{batch_size * kv_max_seqlen, num_gqa_groups, head_dim};
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
......@@ -1509,9 +1526,9 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
......@@ -1550,21 +1567,23 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
void *workspace = buffers[10];
// tensor sizes
auto batch_size = descriptor.batch_size;
auto input_batch = descriptor.input_batch;
auto bias_batch = descriptor.bias_batch;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto num_heads = descriptor.num_heads;
auto attn_heads = descriptor.attn_heads;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto bias_heads = descriptor.bias_heads;
auto head_dim = descriptor.head_dim;
auto scaling_factor = descriptor.scaling_factor;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto k_shape = std::vector<size_t>{batch_size * kv_max_seqlen, num_gqa_groups, head_dim};
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
// input tensors
auto dtype = descriptor.dtype;
......@@ -1577,16 +1596,16 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16
auto o_tensor = TensorWrapper(output, q_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
// prep RNG state
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, num_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
......@@ -1611,16 +1630,17 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
}
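For context, a sketch of the backend-side filter that the PyTorch changes later in this diff implement: the F16 max512 backend only accepts a [1, h, s, s] bias, while the arbitrary-seqlen backend also takes the broadcast and per-batch layouts. The enum values come from the library's fused_attn header; the helper itself and the include path are assumptions.

#include <cstddef>
#include <transformer_engine/fused_attn.h>

// Sketch: which bias layouts a given fused-attention backend can run.
bool BackendAcceptsBiasLayout(NVTE_Fused_Attn_Backend backend,
                              std::size_t bias_batch, std::size_t bias_heads,
                              std::size_t attn_heads) {
    const bool is_1hss = (bias_batch == 1) && (bias_heads == attn_heads);
    if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)
        return is_1hss;   // max512 backend: only [1, h, s, s]
    return true;          // arbitrary-seqlen backend: 11SS/B1SS/BHSS as well
}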
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto k_shape = std::vector<size_t>{batch_size * kv_max_seqlen, num_gqa_groups, head_dim};
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto output_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
......@@ -1636,9 +1656,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
......@@ -1682,22 +1702,24 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
void *workspace = buffers[14];
// tensor sizes
auto batch_size = descriptor.batch_size;
auto input_batch = descriptor.input_batch;
auto bias_batch = descriptor.bias_batch;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto num_heads = descriptor.num_heads;
auto attn_heads = descriptor.attn_heads;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto bias_heads = descriptor.bias_heads;
auto head_dim = descriptor.head_dim;
auto scaling_factor = descriptor.scaling_factor;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto k_shape = std::vector<size_t>{batch_size * kv_max_seqlen, num_gqa_groups, head_dim};
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto output_shape = std::vector<size_t>{batch_size * q_max_seqlen, num_heads, head_dim};
auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
// input tensors
auto dtype = descriptor.dtype;
......@@ -1714,9 +1736,9 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto dv_tensor = TensorWrapper(dv, v_shape, dtype);
auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
// auxiliary tensors (propagated from the forward pass)
NVTETensorPack aux_input_tensors;
......@@ -1724,7 +1746,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, num_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
rng_state, bias);
......
......@@ -105,11 +105,13 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t paddin
DType dtype, float scale_factor);
struct CustomCallFusedAttnDescriptor {
size_t batch_size;
size_t input_batch;
size_t bias_batch;
size_t q_max_seqlen;
size_t kv_max_seqlen;
size_t num_heads;
size_t attn_heads;
size_t num_gqa_groups;
size_t bias_heads;
size_t head_dim;
size_t wkspace_size;
float scaling_factor;
......@@ -122,10 +124,11 @@ struct CustomCallFusedAttnDescriptor {
};
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
DType wkspace_dtype, bool is_training);
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
size_t wkspace_size, float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
DType dtype, DType wkspace_dtype, bool is_training);
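As an illustration of the new descriptor fields, a hedged sketch of a cross-attention configuration with a per-batch, broadcast-over-heads ([b, 1, sq, skv]) bias; the struct comes from this header and is assumed to be in scope, but the wrapper function and all numeric values are made up.

// Sketch only: populate the shape-related fields of the descriptor declared above.
CustomCallFusedAttnDescriptor MakeExampleDescriptor() {
    CustomCallFusedAttnDescriptor desc{};
    desc.input_batch    = 32;       // batch dimension of q/k/v
    desc.bias_batch     = 32;       // b1ss: bias batch matches the input batch
    desc.q_max_seqlen   = 512;
    desc.kv_max_seqlen  = 512;
    desc.attn_heads     = 16;       // query heads
    desc.num_gqa_groups = 16;       // kv heads (equal to attn_heads without GQA)
    desc.bias_heads     = 1;        // b1ss: bias broadcast over heads
    desc.head_dim       = 64;
    desc.scaling_factor = 0.125f;   // 1 / sqrt(head_dim)
    return desc;
}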
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
......@@ -205,47 +208,53 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers,
std::size_t opaque_len);
pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
bool is_training);
size_t input_batch, size_t bias_batch, size_t max_seqlen,
size_t attn_heads, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype,
bool is_training);
size_t input_batch, size_t bias_batch, size_t max_seqlen,
size_t attn_heads, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t num_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
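To make the new argument order concrete, a hedged usage sketch of the forward workspace query declared above, here for a [1, 1, sq, skv] bias; the wrapper function and every numeric value are illustrative assumptions, and in practice the query is driven from the JAX custom-call registration rather than called directly.

// Illustrative only: query forward workspace sizes for separate-Q/K/V attention
// with a broadcast [1, 1, sq, skv] post-scale bias, no dropout, no GQA.
pybind11::tuple ExampleForwardWorkspaceQuery() {
    return GetFusedAttnForwardWorkspaceSizes(
        /*input_batch=*/8, /*bias_batch=*/1, /*q_max_seqlen=*/384, /*kv_max_seqlen=*/384,
        /*attn_heads=*/12, /*num_gqa_groups=*/12, /*bias_heads=*/1, /*head_dim=*/64,
        /*scaling_factor=*/0.125f, /*dropout_probability=*/0.0f,
        NVTE_Bias_Type::NVTE_POST_SCALE_BIAS, NVTE_Mask_Type::NVTE_PADDING_MASK,
        DType::kBFloat16, /*is_training=*/true);
}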
......
......@@ -2784,10 +2784,15 @@ class DotProductAttention(torch.nn.Module):
_, fu_core_attention_bias = get_alibi(
query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, alibi_slopes=alibi_slopes,
bias_dtype=query_layer.dtype)
if (fu_core_attention_bias.shape[0] != 1
or fu_core_attention_bias.shape[1] != query_layer.shape[-2]):
# remove this line when cuDNN adds bwd support for [b, 1, s, s] and [b, h, s, s]
if (use_fused_attention
and fu_core_attention_bias_type == "post_scale_bias"
and (fu_core_attention_bias.shape[0] != 1
or fu_core_attention_bias.shape[1] != query_layer.shape[-2])):
if fu_core_attention_bias.requires_grad:
# remove this line when cuDNN adds bwd support for
# [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
use_fused_attention = False
else:
# max512 backend will only support [1, h, s, s]
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
......@@ -2812,6 +2817,11 @@ class DotProductAttention(torch.nn.Module):
use_fused_attention and is_backend_avail and \
(not context_parallel or \
fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]))
if (fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
and fu_core_attention_bias_type == "post_scale_bias"
and (fu_core_attention_bias.shape[0] != 1
or fu_core_attention_bias.shape[1] != query_layer.shape[-2])):
use_fused_attention = False
# Filter: determinism.
# backend | deterministic
......