Unverified Commit 87939be1 authored by Charlene Yang, committed by GitHub

[C/PyTorch] Add support for multi-latent attention (MLA) (#1039)



* add multi-latent attention for DPA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix Jax/Paddle API
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix typo in test script
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix too-many-boolean lint error
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "fix lint"

This reverts commit 67399a3a6f45bb4ce9e5eaa6bcce40b28e347e5b.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix stride check in get_qkv_layout
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: fix layout_thd tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: debug info
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix merge conflict
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix thd pad_between_seqs=False/True tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 27c6342e
......@@ -77,12 +77,13 @@ class ModelConfig:
batch_size: int,
num_heads: int,
num_gqa_groups: int,
head_dim: int,
head_dim_qk: int,
max_seqlen_q: int,
max_seqlen_kv: int,
dropout_p: float,
attn_mask_type: str,
attn_bias_type: str,
head_dim_v: int = None,
alibi_type: str = "none",
num_layers: int = 1,
bias_shape: str = "1hss",
......@@ -91,9 +92,10 @@ class ModelConfig:
self.batch_size = batch_size
self.num_heads = num_heads
self.num_gqa_groups = num_gqa_groups
self.head_dim = head_dim
self.hidden_size = num_heads * head_dim
self.hidden_size_kv = num_gqa_groups * head_dim
self.head_dim_qk = head_dim_qk
self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v
self.hidden_size = num_heads * head_dim_qk
self.hidden_size_kv = num_gqa_groups * self.head_dim_v
self.max_seqlen_q = max_seqlen_q
self.max_seqlen_kv = max_seqlen_kv
self.dropout_p = dropout_p
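For illustration, a minimal sketch of the updated constructor (argument order as in this test file, assuming the ModelConfig class defined above); head_dim_v defaults to head_dim_qk, so non-MLA configs are unchanged:

# Sketch: an MLA test config with 64-wide Q/K heads and 128-wide V heads.
cfg = ModelConfig(
    8, 16, 16, 64,         # batch_size, num_heads, num_gqa_groups, head_dim_qk
    128, 128, 0.0,         # max_seqlen_q, max_seqlen_kv, dropout_p
    "no_mask", "no_bias",  # attn_mask_type, attn_bias_type
    head_dim_v=128,
)
assert cfg.hidden_size == 16 * 64      # num_heads * head_dim_qk
assert cfg.hidden_size_kv == 16 * 128  # num_gqa_groups * head_dim_v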
......@@ -137,7 +139,11 @@ def _get_attention_backends(
)
core_attention_bias_requires_grad = False
# d=256 is supported by cuDNN 9.0+ for inference but not training
if config.attn_bias_type == "post_scale_bias" and config.head_dim <= 128:
if (
config.attn_bias_type == "post_scale_bias"
and config.head_dim_qk <= 128
and config.head_dim_v <= 128
):
core_attention_bias_requires_grad = True
fused_attn_backends = []
......@@ -153,7 +159,8 @@ def _get_attention_backends(
num_gqa_groups=config.num_gqa_groups,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
head_dim=config.head_dim,
head_dim_qk=config.head_dim_qk,
head_dim_v=config.head_dim_v,
attn_mask_type=config.attn_mask_type,
window_size=window_size,
alibi_slopes_shape=alibi_slopes_shape,
......@@ -218,11 +225,12 @@ def test_dot_product_attention(
if dtype == torch.bfloat16:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
config = model_configs[model]
is_mla = config.head_dim_qk != config.head_dim_v
if qkv_layout is None:
if config.attn_type == "self":
qkv_layout = "sb3hd"
qkv_layout = "sb3hd" if not is_mla else "sbhd_sbhd_sbhd"
else:
qkv_layout = "sbhd_sb2hd"
qkv_layout = "bshd_bs2hd" if not is_mla else "bshd_bshd_bshd"
if "3" in qkv_layout and config.attn_type == "cross":
pytest.skip("No need to test this layout for cross attention")
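The layout fallback above exists because MLA rules out packed buffers: when head_dim_qk != head_dim_v, K and V differ in their last dimension and cannot share a "3hd" or "2hd" allocation. A minimal shape sketch for the "sbhd_sbhd_sbhd" case:

import torch

# Sketch: separate Q/K/V tensors in "sbhd" format; only the V head dim differs.
s, b, h, d_qk, d_v = 128, 8, 16, 64, 128
q = torch.randn(s, b, h, d_qk)
k = torch.randn(s, b, h, d_qk)
v = torch.randn(s, b, h, d_v)  # mismatched last dim -> no "3hd"/"2hd" packing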
......@@ -241,14 +249,17 @@ def test_dot_product_attention(
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
# manually pads and unpads the input and output of FlashAttention for testing purposes
if pad_between_seqs:
if pad_between_seqs and not (
config.max_seqlen_q != config.max_seqlen_kv
and config.attn_mask_type in ["causal", "padding_causal"]
):
flash_attn_supported = True
# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
is_training = config.head_dim <= 128
is_training = config.head_dim_qk <= 128 and config.head_dim_v <= 128
# UnfusedDotProductAttention backend
if unfused_attn_supported:
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
......@@ -343,6 +354,38 @@ def test_dpa_checkpoint(dtype, model_configs, model):
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)
model_configs_mla = {
# test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
"mla_1_0": ModelConfig(
8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias", head_dim_v=128
), # self , 0
"mla_1_1": ModelConfig(
4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128
), # cross, 0
"mla_2_0": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64
), # self , 1
"mla_2_1": ModelConfig(
1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64
), # cross, 1
"mla_3_0": ModelConfig(
8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=64
), # inference
"mla_3_1": ModelConfig(
8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128
), # inference
}
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_mla])
@pytest.mark.parametrize("model", model_configs_mla.keys())
def test_dpa_mla(dtype, model_configs, model):
"""Test DotProductAttention module with Multi-Latent Attention (MLA)"""
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)
model_configs_mask = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
......@@ -586,14 +629,16 @@ model_configs_layout_thd = {
@pytest.mark.parametrize("qkv_layout", qkv_layouts_thd)
def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with different QKV layouts"""
pad_between_seqs = False
test_dot_product_attention(
dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
)
pad_between_seqs = True
test_dot_product_attention(
dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
)
if get_cudnn_version() >= (9, 3, 0):
# cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run
pad_between_seqs = False
test_dot_product_attention(
dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
)
def _run_dot_product_attention(
......@@ -736,7 +781,8 @@ def _run_dot_product_attention(
"skv": config.max_seqlen_kv,
"h": config.num_heads,
"hg": config.num_gqa_groups,
"d": config.head_dim,
"dqk": config.head_dim_qk,
"dv": config.head_dim_v,
"t": cu_seqlens_q_after_pad[-1],
"tg": cu_seqlens_kv_after_pad[-1],
"3": 3,
......@@ -753,12 +799,16 @@ def _run_dot_product_attention(
layout = layout.replace("s", "skv")
layout = layout.replace("h", "hg")
layout = layout.replace("t", "tg")
if i == 2:
layout = layout.replace("d", "dv")
else:
layout = layout.replace("d", "dqk")
tensor_shape = [dim_to_num[j] for j in layout.split("_")]
tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
tensor_orig = tensor
if qkv_format == "thd" and pad_between_seqs:
tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
if layout in ["t_h_d", "t_3_h_d", "t_h_3_d"]:
if layout in ["t_h_dqk", "t_3_h_dqk", "t_h_3_dqk"]:
for i in range(1, config.batch_size + 1):
valid_range = (
cu_seqlens_q_after_pad[i - 1],
......@@ -772,7 +822,7 @@ def _run_dot_product_attention(
tensor_orig = torch.cat(
[tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
)
if layout in ["tg_hg_d", "tg_2_hg_d", "tg_hg_2_d"]:
if layout in ["tg_hg_dqk", "tg_2_hg_dqk", "tg_hg_2_dqk", "tg_hg_dv"]:
for i in range(1, config.batch_size + 1):
valid_range = (
cu_seqlens_kv_after_pad[i - 1],
......@@ -811,13 +861,14 @@ def _run_dot_product_attention(
# Create output gradient
qkv_format_kv = "_".join(qkv_format)
qkv_format_kv = qkv_format_kv.replace("s", "sq")
qkv_format_kv = qkv_format_kv.replace("d", "dv")
out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda")
out_grad_orig = out_grad
if qkv_format == "thd" and pad_between_seqs:
out_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
if qkv_format_kv == "t_h_d":
if qkv_format_kv == "t_h_dv":
for i in range(1, config.batch_size + 1):
valid_range = (
cu_seqlens_q_after_pad[i - 1],
......@@ -851,7 +902,7 @@ def _run_dot_product_attention(
# Set up model
block = DotProductAttention(
config.num_heads,
config.head_dim,
config.head_dim_qk,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
......@@ -906,9 +957,10 @@ def _run_dot_product_attention(
if backend == "FusedAttention":
if qkv_format == "thd" and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
q_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
k_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
v_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
if is_training:
q_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
k_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
v_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
for i in range(1, config.batch_size + 1):
valid_range_q = (
cu_seqlens_q_after_pad[i - 1],
......@@ -919,15 +971,16 @@ def _run_dot_product_attention(
cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
)
out_orig = torch.cat([out_orig, out[valid_range_q[0] : valid_range_q[1]]], dim=0)
q_grad_orig = torch.cat(
[q_grad_orig, q.grad[valid_range_q[0] : valid_range_q[1]]], dim=0
)
k_grad_orig = torch.cat(
[k_grad_orig, k.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
)
v_grad_orig = torch.cat(
[v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
)
if is_training:
q_grad_orig = torch.cat(
[q_grad_orig, q.grad[valid_range_q[0] : valid_range_q[1]]], dim=0
)
k_grad_orig = torch.cat(
[k_grad_orig, k.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
)
v_grad_orig = torch.cat(
[v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
)
if is_training:
return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig)
else:
......@@ -1168,7 +1221,7 @@ def _run_transformer_layer(
# Create RoPE
rotary_pos_emb = None
if RoPE:
PE = RotaryPositionEmbedding(dim=config.head_dim)
PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")
# Set up model
......@@ -1183,7 +1236,7 @@ def _run_transformer_layer(
init_method=init_method,
output_layer_init_method=output_layer_init_method,
layer_number=layer_number,
kv_channels=config.head_dim,
kv_channels=config.head_dim_qk,
self_attn_mask_type=config.attn_mask_type,
tp_group=None,
tp_size=1,
......@@ -1356,7 +1409,7 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
mha = MultiheadAttention(
hidden_size=config.hidden_size,
num_attention_heads=config.num_heads,
kv_channels=config.head_dim,
kv_channels=config.head_dim_qk,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
layer_number=1,
......@@ -1387,7 +1440,7 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
"skv": config.max_seqlen_kv,
"h": config.num_heads,
"hg": config.num_gqa_groups,
"d": config.head_dim,
"d": config.head_dim_qk,
"t": cu_seqlens_q[-1],
"tg": cu_seqlens_kv[-1],
"3": 3,
......@@ -1531,7 +1584,7 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
with fp8_model_init(enabled=fp8_dpa):
dpa = DotProductAttention(
config.num_heads,
config.head_dim,
config.head_dim_qk,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
sequence_parallel=False,
......@@ -1560,7 +1613,7 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
"skv": config.max_seqlen_kv,
"h": config.num_heads,
"hg": config.num_gqa_groups,
"d": config.head_dim,
"d": config.head_dim_qk,
"t": cu_seqlens_q[-1],
"tg": cu_seqlens_kv[-1],
"3": 3,
......@@ -1732,7 +1785,7 @@ def _run_custom_mha_fp8(dtype, config, backend):
inp = 0.0001 * torch.randint(
-100,
100,
(config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim),
(config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim_qk),
dtype=dtype,
device="cuda",
requires_grad=True,
......@@ -1743,7 +1796,7 @@ def _run_custom_mha_fp8(dtype, config, backend):
out_grad = 0.01 * torch.randn(
config.batch_size * config.max_seqlen_q,
config.num_heads * config.head_dim,
config.num_heads * config.head_dim_qk,
dtype=dtype,
device="cuda",
)
......@@ -1766,7 +1819,7 @@ def _run_custom_mha_fp8(dtype, config, backend):
return (
out.view(config.batch_size, config.max_seqlen_q, -1),
dqkv.view(
config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim
config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim_qk
).contiguous(),
)
......@@ -1809,7 +1862,7 @@ def _run_ref_mha_f16(dtype, config, backend):
block = DotProductAttention(
config.num_heads,
config.head_dim,
config.head_dim_qk,
attention_dropout=config.dropout_p,
sequence_parallel=False,
tp_size=1,
......@@ -2105,7 +2158,7 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
self.p_dropout = config.dropout_p
self.h = config.num_heads
self.hidden_size = config.hidden_size
self.head_dim = config.head_dim
self.head_dim = config.head_dim_qk
self.fast_zero_fill = True
self.mask_type = config.attn_mask_type
......
......@@ -1083,7 +1083,7 @@ def test_export_core_attention(
model = te.attention.DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=kv_channels,
k_channels=kv_channels,
attention_dropout=0.5,
qkv_format=qkv_format,
attn_mask_type=attn_mask_type,
......
......@@ -72,8 +72,8 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) {
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left,
int64_t window_size_right) {
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
int64_t window_size_left, int64_t window_size_right) {
using namespace transformer_engine;
NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
const int device_id = cuda::current_device();
......@@ -84,10 +84,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
if (((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2)) &&
(sm_arch_ >= 90) && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) &&
(((cudnn_runtime_version >= 8900) && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) &&
(max_seqlen_q == max_seqlen_kv) && (max_seqlen_q <= 512) && (head_dim == 64) &&
(attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)) ||
(max_seqlen_q == max_seqlen_kv) && (max_seqlen_q <= 512) && (head_dim_qk == 64) &&
(head_dim_v == 64) && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)) ||
((cudnn_runtime_version >= 90201) && (max_seqlen_q % 128 == 0) &&
(max_seqlen_kv % 128 == 0) && (head_dim == 128) &&
(max_seqlen_kv % 128 == 0) && (head_dim_qk == 128) && (head_dim_v == 128) &&
((qkv_format == NVTE_QKV_Format::NVTE_BSHD) ||
(qkv_format == NVTE_QKV_Format::NVTE_SBHD)) &&
((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
......@@ -104,8 +104,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
bool flag_m512 = false;
bool flag_arb = false;
if ((sm_arch_ == 80 || sm_arch_ == 90) && (max_seqlen_q <= 512 && max_seqlen_q % 64 == 0) &&
(max_seqlen_kv <= 512 && max_seqlen_kv % 64 == 0) && (head_dim == 64) &&
(num_attn_heads == num_gqa_groups) &&
(max_seqlen_kv <= 512 && max_seqlen_kv % 64 == 0) && (head_dim_qk == 64) &&
(head_dim_v == 64) && (num_attn_heads == num_gqa_groups) &&
((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
(bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) &&
((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
......@@ -131,11 +131,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) ||
(cudnn_runtime_version >= 8907)) &&
// head dimension
((head_dim <= 128 && head_dim % 8 == 0) ||
((head_dim_qk <= 128 && head_dim_qk % 8 == 0 && head_dim_v <= 128 && head_dim_v % 8 == 0) ||
// TODO (cyang): add is_training to nvte_get_fused_attn_backend
// d=256 only supported for forward
(sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim <= 256 &&
head_dim % 8 == 0)) &&
(sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim_qk <= 256 &&
head_dim_qk % 8 == 0 && head_dim_v <= 256 && head_dim_v % 8 == 0)) &&
// bias type
((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
((cudnn_runtime_version >= 8906) &&
......@@ -155,6 +155,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) ||
((cudnn_runtime_version >= 90300) &&
attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 &&
bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
(qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) &&
max_seqlen_q <= max_seqlen_kv && dropout == 0.0)) &&
......@@ -259,7 +260,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
max_seqlen, d, window_size_left, window_size_right);
max_seqlen, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -336,7 +337,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
max_seqlen, d, window_size_left, window_size_right);
max_seqlen, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -430,7 +431,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
max_seqlen_kv, d, window_size_left, window_size_right);
max_seqlen_kv, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -514,7 +515,7 @@ void nvte_fused_attn_bwd_kvpacked(
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
max_seqlen_kv, d, window_size_left, window_size_right);
max_seqlen_kv, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -595,7 +596,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
size_t h_q = input_Q->data.shape[ndim - 2];
size_t h_kv = input_K->data.shape[ndim - 2];
size_t d = input_Q->data.shape[ndim - 1];
size_t d_qk = input_Q->data.shape[ndim - 1];
size_t d_v = input_V->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
......@@ -603,13 +605,13 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
max_seqlen_kv, d, window_size_left, window_size_right);
max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V,
input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale,
dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K,
input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
......@@ -617,18 +619,18 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_K, input_V,
input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
handle);
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q,
input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale,
fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale,
dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V,
input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle);
......@@ -674,7 +676,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
size_t h_q = input_Q->data.shape[ndim - 2];
size_t h_kv = input_K->data.shape[ndim - 2];
size_t d = input_Q->data.shape[ndim - 1];
size_t d_qk = input_Q->data.shape[ndim - 1];
size_t d_v = input_V->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
......@@ -682,15 +685,15 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
max_seqlen_kv, d, window_size_left, window_size_right);
max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, input_Q, input_K, input_V, input_dO, output_S,
output_dQ, output_dK, output_dV, output_dBias, input_cu_seqlens_q,
input_cu_seqlens_kv, wkspace, stream, handle);
fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V,
input_dO, output_S, output_dQ, output_dK, output_dV, output_dBias,
input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
......@@ -705,9 +708,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
}
fused_attn_arbitrary_seqlen_bwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, input_K,
input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, output_dV,
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q,
input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, output_dV,
output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else
......@@ -721,7 +724,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const Tensor *input_M = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[0]);
const Tensor *input_ZInv = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[1]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[2]);
fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout,
fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O,
input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ,
output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv,
......
......@@ -58,8 +58,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
void fused_attn_arbitrary_seqlen_fwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
......@@ -68,7 +68,7 @@ void fused_attn_arbitrary_seqlen_fwd(
void fused_attn_arbitrary_seqlen_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
......
......@@ -1679,6 +1679,7 @@ void fused_attn_fp8_fwd_impl_v1(
s_q,
s_kv,
d,
d,
bias_b,
bias_h,
scaling_factor,
......@@ -1976,6 +1977,7 @@ void fused_attn_fp8_bwd_impl_v1(
s_q,
s_kv,
d,
d,
bias_b,
bias_h,
scaling_factor,
......
......@@ -363,29 +363,30 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu
// convert cu_seqlens_padded to offsets
__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h,
size_t hg, size_t d, int32_t *cu_seqlens_q_padded,
size_t hg, size_t d_qk, size_t d_v,
int32_t *cu_seqlens_q_padded,
int32_t *cu_seqlens_kv_padded, int32_t *offsets_q,
int32_t *offsets_k, int32_t *offsets_v,
int32_t *offsets_o) {
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < b + 1) {
offsets_o[tid] = h * d * cu_seqlens_q_padded[tid];
offsets_o[tid] = h * d_v * cu_seqlens_q_padded[tid];
switch (layout_group) {
case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
offsets_q[tid] = h * d * cu_seqlens_q_padded[tid];
offsets_k[tid] = hg * d * cu_seqlens_kv_padded[tid];
offsets_v[tid] = offsets_k[tid];
offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[tid];
offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[tid];
offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[tid];
break;
case NVTE_QKV_Layout_Group::NVTE_3HD:
case NVTE_QKV_Layout_Group::NVTE_H3D:
offsets_q[tid] = 3 * h * d * cu_seqlens_q_padded[tid];
offsets_q[tid] = 3 * h * d_qk * cu_seqlens_q_padded[tid];
offsets_k[tid] = offsets_q[tid];
offsets_v[tid] = offsets_q[tid];
break;
case NVTE_QKV_Layout_Group::NVTE_HD_2HD:
case NVTE_QKV_Layout_Group::NVTE_HD_H2D:
offsets_q[tid] = h * d * cu_seqlens_q_padded[tid];
offsets_k[tid] = 2 * hg * d * cu_seqlens_kv_padded[tid];
offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[tid];
offsets_k[tid] = 2 * hg * d_qk * cu_seqlens_kv_padded[tid];
offsets_v[tid] = offsets_k[tid];
break;
}
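A Python rendering of the per-sequence offset arithmetic in the NVTE_HD_HD_HD branch above (a sketch of the same math, not the kernel itself): V offsets now scale with d_v, and the output offsets follow the V head dim.

# Sketch: element offsets at one sequence boundary `tid` for hd_hd_hd layouts.
def offsets_hd_hd_hd(h, hg, d_qk, d_v, cu_q_pad, cu_kv_pad, tid):
    off_q = h * d_qk * cu_q_pad[tid]
    off_k = hg * d_qk * cu_kv_pad[tid]
    off_v = hg * d_v * cu_kv_pad[tid]  # no longer equal to off_k under MLA
    off_o = h * d_v * cu_q_pad[tid]    # output rows are d_v wide
    return off_q, off_k, off_v, off_o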
......
......@@ -91,7 +91,8 @@ struct FADescriptor_v1 {
std::int64_t hg;
std::int64_t s_q;
std::int64_t s_kv;
std::int64_t d;
std::int64_t d_qk;
std::int64_t d_v;
std::int64_t bias_b;
std::int64_t bias_h;
float attnScale;
......@@ -107,11 +108,11 @@ struct FADescriptor_v1 {
cudnn_frontend::DataType_t bwd_tensor_type;
bool operator<(const FADescriptor_v1 &rhs) const {
return std::tie(b, h, hg, s_q, s_kv, d, bias_b, bias_h, attnScale, isTraining,
return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, bias_b, bias_h, attnScale, isTraining,
dropoutProbability, layout, mask_type, window_size_left, window_size_right,
deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) <
std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d, rhs.bias_b, rhs.bias_h,
rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout,
std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.bias_b,
rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout,
rhs.mask_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic,
rhs.bias_type, rhs.fwd_tensor_type, rhs.bwd_tensor_type);
}
......@@ -126,7 +127,8 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu
int32_t *kv_seqlens);
__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h,
size_t hg, size_t d, int32_t *cu_seqlens_q_padded,
size_t hg, size_t d_qk, size_t d_v,
int32_t *cu_seqlens_q_padded,
int32_t *cu_seqlens_kv_padded, int32_t *offsets_q,
int32_t *offsets_k, int32_t *offsets_v,
int32_t *offsets_o);
......
......@@ -147,15 +147,16 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout);
* \param[in] num_gqa_groups The number of heads in K, V.
* \param[in] max_seqlen_q The sequence length of Q.
* \param[in] max_seqlen_kv The sequence length of K, V.
* \param[in] head_dim The head dimension of Q, K, V.
* \param[in] head_dim_qk The head dimension of Q, K.
* \param[in] head_dim_v The head dimension of V.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
*/
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left,
int64_t window_size_right);
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
int64_t window_size_left, int64_t window_size_right);
/*! \brief Compute dot product attention with packed QKV input.
*
......
......@@ -19,7 +19,7 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen,
head_dim, -1, -1);
head_dim, head_dim, -1, -1);
return backend;
}
......@@ -255,10 +255,10 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
/* Prepare RNG state */
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend =
nvte_get_fused_attn_backend(static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype),
qkv_layout, bias_type, mask_type, dropout_probability, attn_heads,
num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim, -1, -1);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim, head_dim, -1, -1);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
/* Auxiliary tensors (to be propagated to the backward pass later) */
......@@ -486,10 +486,10 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
/* Auxiliary tensors (propagated from the forward pass) */
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
auto backend =
nvte_get_fused_attn_backend(static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype),
qkv_layout, bias_type, mask_type, dropout_probability, attn_heads,
num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim, -1, -1);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim, head_dim, -1, -1);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
rng_state, bias);
......
......@@ -131,10 +131,10 @@ inline NVTE_Fused_Attn_Backend get_fused_attn_backend(
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim) {
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype),
qkv_layout, bias_type, attn_mask_type, p_dropout, num_attn_heads,
num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, -1, -1);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv,
head_dim, head_dim, -1, -1);
return fused_attention_backend;
}
......
......@@ -142,8 +142,10 @@ class AttentionParams:
Maximum sequence length of the query tensor.
max_seqlen_kv: int, default = 128
Maximum sequence length of the key and value tensors.
head_dim: int, default = 64
The size of each attention head.
head_dim_qk: int, default = 64
The size of each attention head in query and key tensors.
head_dim_v: int, default = 64
The size of each attention head in the value tensor.
attn_mask_type: str, default = `no_mask`
Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`,
`causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
......@@ -182,7 +184,8 @@ class AttentionParams:
num_gqa_groups: int = 16
max_seqlen_q: int = 128
max_seqlen_kv: int = 128
head_dim: int = 64
head_dim_qk: int = 64
head_dim_v: int = 64
attn_mask_type: str = "no_mask"
window_size: Union[Tuple[int, int], None] = None
alibi_slopes_shape: Union[torch.Size, List, None] = None
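For illustration, a hypothetical backend query for an MLA shape, using only the dataclass fields shown above and assuming the remaining fields keep their defaults:

# Sketch: head_dim_qk != head_dim_v marks this query as MLA, which the
# backend filters below use to rule out FlashAttention.
params = AttentionParams(
    num_gqa_groups=16,
    max_seqlen_q=128,
    max_seqlen_kv=128,
    head_dim_qk=64,
    head_dim_v=128,
    attn_mask_type="no_mask",
)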
......@@ -245,7 +248,8 @@ def get_attention_backend(
num_gqa_groups = attention_params.num_gqa_groups
max_seqlen_q = attention_params.max_seqlen_q
max_seqlen_kv = attention_params.max_seqlen_kv
head_dim = attention_params.head_dim
head_dim_qk = attention_params.head_dim_qk
head_dim_v = attention_params.head_dim_v
attn_mask_type = attention_params.attn_mask_type
window_size = attention_params.window_size
alibi_slopes_shape = attention_params.alibi_slopes_shape
......@@ -352,19 +356,31 @@ def get_attention_backend(
use_unfused_attention = False
# Filter: Head dimension
if use_flash_attention and head_dim_qk != head_dim_v:
logger.debug("Disabling FlashAttention as it does not support MLA.")
use_flash_attention = False
if use_flash_attention and (
head_dim > 256
or head_dim % 8 != 0
or (head_dim > 192 and device_compute_capability not in ((8, 0), (9, 0)))
head_dim_qk > 256
or head_dim_qk % 8 != 0
or (head_dim_qk > 192 and device_compute_capability not in ((8, 0), (9, 0)))
):
logger.debug(
"Disabling FlashAttention due to unsupported head_dim. "
"Supported: head_dim %%8 = 0, head_dim <= 256 (>192 requires sm80/90). "
"Found: head_dim = %s on sm%s.",
head_dim,
"Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. "
"Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
"head_dim_qk <= 256 (>192 requires sm80/90). "
"Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
head_dim_qk,
head_dim_v,
".".join([str(i) for i in device_compute_capability]),
)
use_flash_attention = False
qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
if use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd":
logger.debug(
"Disabling FusedAttention as MLA is not supported with qkv_layout = %s",
qkv_layout,
)
use_fused_attention = False
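The layout-group string used by this filter is derived by stripping the batch/sequence axes from the layout name; a quick sketch of the values it sees:

# Sketch: deriving the layout group checked by the MLA filter above.
for layout in ("sb3hd", "bshd_bs2hd", "sbhd_sbhd_sbhd"):
    group = layout.replace("b", "").replace("s", "").replace("t", "")
    print(layout, "->", group)
# sb3hd          -> 3hd       (packed QKV: rejected for MLA)
# bshd_bs2hd     -> hd_2hd    (packed KV: rejected for MLA)
# sbhd_sbhd_sbhd -> hd_hd_hd  (separate tensors: allowed)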
# Filter: QKV layout
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
......@@ -557,7 +573,8 @@ def get_attention_backend(
num_gqa_groups,
max_seqlen_q,
max_seqlen_kv,
head_dim,
head_dim_qk,
head_dim_v,
window_size[0],
window_size[1],
)
......@@ -3132,12 +3149,14 @@ def get_qkv_layout(
stride = q.stride()
check_strides_qkv = all(stride == x.stride() for x in [q, k, v])
stride = k.stride()
check_strides_kv = all(stride == x.stride() for x in [k, v])
check_strides_kv = torch.equal(
torch.Tensor(stride[:-1]) / k.shape[-1], torch.Tensor(v.stride()[:-1]) / v.shape[-1]
)
shape = q.shape
check_shapes_qkv = all(shape == x.shape for x in [q, k, v])
shape = k.shape
check_shapes_kv = all(shape == x.shape for x in [k, v])
check_shapes_kv = shape[:-1] == v.shape[:-1]
last_dim_size = q.shape[-1]
check_last_dim_offsets_qkv = all(
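The relaxed K/V checks above compare strides normalized by the head dimension, so K and V that share a memory layout but differ in head size still resolve to the same qkv_layout. A minimal sketch of the idea:

import torch

# Sketch: normalized strides match for same-layout K/V with different head dims.
k = torch.randn(128, 8, 16, 192)  # "sbhd", d_qk = 192
v = torch.randn(128, 8, 16, 128)  # "sbhd", d_v  = 128
norm_k = [s / k.shape[-1] for s in k.stride()[:-1]]
norm_v = [s / v.shape[-1] for s in v.stride()[:-1]]
assert norm_k == norm_v  # same "sbhd" pattern despite different last dims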
......@@ -5177,8 +5196,10 @@ class DotProductAttention(TransformerEngineBaseModule):
----------
num_attention_heads : int
number of attention heads in the transformer layer.
kv_channels : int
number of key-query-value channels per attention head.
k_channels : int
number of channels per attention head in key.
v_channels : Optional[int] = None
number of channels per attention head in value.
num_gqa_groups : Optional[int] = None
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
......@@ -5264,7 +5285,8 @@ class DotProductAttention(TransformerEngineBaseModule):
def __init__(
self,
num_attention_heads: int,
kv_channels: int,
k_channels: int,
v_channels: Optional[int] = None,
num_gqa_groups: Optional[int] = None,
attention_dropout: float = 0.0,
qkv_format: str = "sbhd",
......@@ -5304,7 +5326,8 @@ class DotProductAttention(TransformerEngineBaseModule):
self.cp_global_ranks = cp_global_ranks
self.cp_stream = cp_stream
self.hidden_size_per_attention_head = kv_channels
self.hidden_size_per_attention_head = k_channels
self.v_channels = k_channels if v_channels is None else v_channels
self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)
......@@ -5322,7 +5345,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attention_dropout_ctx = self.rng_states_tracker.fork
if softmax_scale is None:
softmax_scale = 1.0 / math.sqrt(kv_channels)
softmax_scale = 1.0 / math.sqrt(k_channels)
self.deterministic = (
not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
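With the renamed arguments, an MLA module can be constructed as below (a sketch; head sizes are illustrative). v_channels defaults to k_channels, so existing kv_channels-style callers keep their behavior, and the softmax scale follows k_channels:

# Sketch: DotProductAttention with distinct Q/K and V head dims.
block = DotProductAttention(
    16,              # num_attention_heads
    192,             # k_channels: Q/K head dim (sets softmax_scale = 1/sqrt(192))
    v_channels=128,  # V head dim; omit for standard attention
    num_gqa_groups=16,
    qkv_format="sbhd",
)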
......@@ -5469,16 +5492,6 @@ class DotProductAttention(TransformerEngineBaseModule):
Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
includes '"padding"' or `"arbitrary"`.
.. note::
Input tensor :attr:`query_layer` must be of shape
(:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`,
:attr:`kv_channels`) and the tensors :attr:`key_layer` and :attr:`value_layer`
must each be of shape (:attr:`sequence_length`, :attr:`batch_size`,
:attr:`num_gqa_groups`, :attr:`kv_channels`). Output of shape
(:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`
* :attr:`kv_channels`) is returned.
.. note::
DotProductAttention supports three backends: 1) FlashAttention which calls
......@@ -5628,7 +5641,9 @@ class DotProductAttention(TransformerEngineBaseModule):
assert (
query_layer.dtype == key_layer.dtype and query_layer.dtype == value_layer.dtype
), "Queries, keys and values must have the same data type!"
assert key_layer.shape == value_layer.shape, "Keys and values must have the same shape!"
assert (
key_layer.shape[:-1] == value_layer.shape[:-1]
), "Keys and values must have the same batch size, sequence length and number of heads!"
if attn_mask_type is None:
attn_mask_type = self.attn_mask_type
......@@ -5861,7 +5876,8 @@ class DotProductAttention(TransformerEngineBaseModule):
num_gqa_groups=key_layer.shape[-2],
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
head_dim=query_layer.shape[-1],
head_dim_qk=query_layer.shape[-1],
head_dim_v=value_layer.shape[-1],
attn_mask_type=attn_mask_type,
window_size=window_size,
alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None,
......
......@@ -140,7 +140,7 @@ def fused_attn_fwd_qkvpacked(
output tensor, amax of O, used by the next iteration in FP8 computations
attn_scale: float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim) as the default
if None, use 1.0/sqrt(head_dim_qk) as the default
dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False
......@@ -342,7 +342,7 @@ def fused_attn_bwd_qkvpacked(
output tensor, amax of dQKV, used by the next iteration in FP8 computations
attn_scale: float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim) as the default
if None, use 1.0/sqrt(head_dim_qk) as the default
dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False
......@@ -508,7 +508,7 @@ def fused_attn_fwd_kvpacked(
output tensor, amax of O, used by the next iteration in FP8 computations
attn_scale: float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim) as the default
if None, use 1.0/sqrt(head_dim_qk) as the default
dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False
......@@ -729,7 +729,7 @@ def fused_attn_bwd_kvpacked(
output tensor, amax of dQKV, used by the next iteration in FP8 computations
attn_scale: float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim) as the default
if None, use 1.0/sqrt(head_dim_qk) as the default
dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False
......@@ -907,7 +907,7 @@ def fused_attn_fwd(
output tensor, amax of O, used by the next iteration in FP8 computations
attn_scale: float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim) as the default
if None, use 1.0/sqrt(head_dim_qk) as the default
dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False
......@@ -1135,7 +1135,7 @@ def fused_attn_bwd(
output tensor, amax of dQ, dK and dV, used by the next iteration in FP8 computations
attn_scale: float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim) as the default
if None, use 1.0/sqrt(head_dim_qk) as the default
dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False
......
......@@ -14,11 +14,14 @@
* Attention
**************************************************************************************************/
NVTE_Fused_Attn_Backend get_fused_attn_backend(
const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left, int64_t window_size_right);
NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q_dtype,
const transformer_engine::DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, float p_dropout,
size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv,
size_t head_dim_qk, size_t head_dim_v,
int64_t window_size_left, int64_t window_size_right);
std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero,
......
......@@ -14,11 +14,12 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left, int64_t window_size_right) {
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right) {
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv,
head_dim, window_size_left, window_size_right);
head_dim_qk, head_dim_v, window_size_left, window_size_right);
return fused_attention_backend;
}
......@@ -761,7 +762,11 @@ std::vector<at::Tensor> fused_attn_fwd(
std::vector<size_t> v_shape{v_sizes.begin(), v_sizes.end()};
// create output tensor O
auto O = torch::empty_like(Q);
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
auto o_shape = std::vector<size_t>{q_sizes.begin(), q_sizes.end()};
o_shape[o_shape.size() - 1] = v_sizes[v_sizes.size() - 1];
std::vector<int64_t> o_shape_tmp{o_shape.begin(), o_shape.end()};
auto O = torch::empty(c10::IntArrayRef(o_shape_tmp), options);
// construct NVTE tensors
TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias;
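The output-shape rule applied above, sketched in Python: O inherits Q's shape except for the last dimension, which follows V's head dim, since each output row is a weighted sum of V rows.

import torch

# Sketch: deriving the MLA output shape from Q and V.
q = torch.randn(2, 128, 16, 192)  # (b, s, h, d_qk)
v = torch.randn(2, 128, 16, 128)  # (b, s, hg, d_v)
o_shape = list(q.shape)
o_shape[-1] = v.shape[-1]         # -> [2, 128, 16, 128]
o = torch.empty(o_shape, dtype=q.dtype)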
......@@ -790,7 +795,7 @@ std::vector<at::Tensor> fused_attn_fwd(
descale_QKV.value().data_ptr());
te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(),
scale_S.value().data_ptr(), descale_S.value().data_ptr());
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(),
te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, amax_O.value().data_ptr(),
scale_O.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
......@@ -801,7 +806,7 @@ std::vector<at::Tensor> fused_attn_fwd(
te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr);
te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, nullptr);
te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr);
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr);
te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, nullptr, nullptr, nullptr);
} else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
}
......@@ -839,8 +844,7 @@ std::vector<at::Tensor> fused_attn_fwd(
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
auto options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA);
auto rng_state = torch::empty({2}, options);
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
philox_args, static_cast<int64_t *>(rng_state.data_ptr()));
auto te_rng_state = makeTransformerEngineTensor(rng_state);
......@@ -935,8 +939,11 @@ std::vector<at::Tensor> fused_attn_bwd(
std::vector<size_t> v_shape{v_sizes.begin(), v_sizes.end()};
auto h_q = q_shape[q_shape.size() - 2];
auto h_kv = k_shape[k_shape.size() - 2];
auto d = q_shape[q_shape.size() - 1];
auto d_qk = q_shape[q_shape.size() - 1];
auto d_v = v_shape[v_shape.size() - 1];
auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA);
std::vector<size_t> o_shape{q_sizes.begin(), q_sizes.end()};
o_shape[o_shape.size() - 1] = d_v;
at::Tensor dQ;
at::Tensor dK;
......@@ -1015,7 +1022,7 @@ std::vector<at::Tensor> fused_attn_bwd(
TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
if (set_zero && ((h_q * d) % block_size == 0) && ((h_kv * d) % block_size == 0) &&
if (set_zero && ((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) &&
dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous() &&
(nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
......@@ -1041,9 +1048,9 @@ std::vector<at::Tensor> fused_attn_bwd(
descale_QKV.value().data_ptr());
te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr,
descale_QKV.value().data_ptr());
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr,
te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, nullptr, nullptr,
descale_O.value().data_ptr());
te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr,
te_dO = makeTransformerEngineTensor(dO.data_ptr(), o_shape, dqkv_type, nullptr, nullptr,
descale_dO.value().data_ptr());
te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr,
scale_S.value().data_ptr(), descale_S.value().data_ptr());
......@@ -1068,9 +1075,9 @@ std::vector<at::Tensor> fused_attn_bwd(
te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr);
te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr);
te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, nullptr);
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr);
te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, nullptr, nullptr, nullptr);
te_dO =
makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, nullptr);
makeTransformerEngineTensor(dO.data_ptr(), o_shape, dqkv_type, nullptr, nullptr, nullptr);
te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr);
te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr);
te_dQ =
......