"tests/cpp_tests/test_chunked_array.cpp" did not exist on "971b5486873339bcaa2bbe0948c95e1fff7246c7"
Unverified commit 87939be1, authored by Charlene Yang, committed by GitHub

[C/PyTorch] Add support for multi-latent attention (MLA) (#1039)



* add multi-latent attention for DPA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix Jax/Paddle API
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix typo in test script
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix too-many-boolean lint error
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "fix lint"

This reverts commit 67399a3a6f45bb4ce9e5eaa6bcce40b28e347e5b.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix stride check in get_qkv_layout
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: fix layout_thd tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP: debug info
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix merge conflict
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix thd pad_between_seqs=False/True tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 27c6342e
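In this change, "multi-latent attention" support means the Q/K head dimension (head_dim_qk) may differ from the V head dimension (head_dim_v), with the attention output inheriting the V head dimension. A minimal shape sketch in plain PyTorch (hypothetical sizes, not code from the patch):

    import torch

    b, h, s_q, s_kv, d_qk, d_v = 2, 16, 128, 256, 192, 128
    q = torch.randn(b, h, s_q, d_qk)
    k = torch.randn(b, h, s_kv, d_qk)   # K shares the Q head dim
    v = torch.randn(b, h, s_kv, d_v)    # V may use a different head dim
    p = torch.softmax(q @ k.transpose(-2, -1) / d_qk**0.5, dim=-1)  # [b, h, s_q, s_kv]
    out = p @ v                          # [b, h, s_q, d_v]
    assert out.shape == (b, h, s_q, d_v)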
@@ -77,12 +77,13 @@ class ModelConfig:
batch_size: int,
num_heads: int,
num_gqa_groups: int,
- head_dim: int,
+ head_dim_qk: int,
max_seqlen_q: int,
max_seqlen_kv: int,
dropout_p: float,
attn_mask_type: str,
attn_bias_type: str,
+ head_dim_v: int = None,
alibi_type: str = "none",
num_layers: int = 1,
bias_shape: str = "1hss",
@@ -91,9 +92,10 @@ class ModelConfig:
self.batch_size = batch_size
self.num_heads = num_heads
self.num_gqa_groups = num_gqa_groups
- self.head_dim = head_dim
- self.hidden_size = num_heads * head_dim
- self.hidden_size_kv = num_gqa_groups * head_dim
+ self.head_dim_qk = head_dim_qk
+ self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v
+ self.hidden_size = num_heads * head_dim_qk
+ self.hidden_size_kv = num_gqa_groups * self.head_dim_v
self.max_seqlen_q = max_seqlen_q
self.max_seqlen_kv = max_seqlen_kv
self.dropout_p = dropout_p
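As a quick illustration (a hypothetical config, not one from the test suite, assuming the ModelConfig class above is in scope): passing head_dim_v switches a config into MLA mode, while omitting it keeps the previous single-head-dim behaviour.

    # Hypothetical MLA config: 192-dim Q/K heads, 128-dim V heads.
    cfg = ModelConfig(2, 16, 16, 192, 512, 512, 0.0, "causal", "no_bias", head_dim_v=128)
    assert cfg.head_dim_v == 128
    assert cfg.hidden_size == 16 * 192     # num_heads * head_dim_qk
    assert cfg.hidden_size_kv == 16 * 128  # num_gqa_groups * head_dim_v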
@@ -137,7 +139,11 @@ def _get_attention_backends(
)
core_attention_bias_requires_grad = False
# d=256 is supported by cuDNN 9.0+ for inference but not training
- if config.attn_bias_type == "post_scale_bias" and config.head_dim <= 128:
+ if (
+ config.attn_bias_type == "post_scale_bias"
+ and config.head_dim_qk <= 128
+ and config.head_dim_v <= 128
+ ):
core_attention_bias_requires_grad = True
fused_attn_backends = []
@@ -153,7 +159,8 @@ def _get_attention_backends(
num_gqa_groups=config.num_gqa_groups,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
- head_dim=config.head_dim,
+ head_dim_qk=config.head_dim_qk,
+ head_dim_v=config.head_dim_v,
attn_mask_type=config.attn_mask_type,
window_size=window_size,
alibi_slopes_shape=alibi_slopes_shape,
@@ -218,11 +225,12 @@ def test_dot_product_attention(
if dtype == torch.bfloat16:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
config = model_configs[model]
+ is_mla = config.head_dim_qk != config.head_dim_v
if qkv_layout is None:
if config.attn_type == "self":
- qkv_layout = "sb3hd"
+ qkv_layout = "sb3hd" if not is_mla else "sbhd_sbhd_sbhd"
else:
- qkv_layout = "sbhd_sb2hd"
+ qkv_layout = "bshd_bs2hd" if not is_mla else "bshd_bshd_bshd"
if "3" in qkv_layout and config.attn_type == "cross":
pytest.skip("No need to test this layout for cross attention")
...@@ -241,14 +249,17 @@ def test_dot_product_attention( ...@@ -241,14 +249,17 @@ def test_dot_product_attention(
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# FlashAttention does not support pad_between_seqs, but _run_dot_product_attention # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
# mannually pads and unpads the input and output of FlashAttention for testing purposes # mannually pads and unpads the input and output of FlashAttention for testing purposes
if pad_between_seqs: if pad_between_seqs and not (
config.max_seqlen_q != config.max_seqlen_kv
and config.attn_mask_type in ["causal", "padding_causal"]
):
flash_attn_supported = True flash_attn_supported = True
# Skip if only unfused backend is supported # Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.") pytest.skip("Less than two backends to compare.")
is_training = config.head_dim <= 128 is_training = config.head_dim_qk <= 128 and config.head_dim_v <= 128
# UnfusedDotProductAttention backend # UnfusedDotProductAttention backend
if unfused_attn_supported: if unfused_attn_supported:
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention( unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
@@ -343,6 +354,38 @@ def test_dpa_checkpoint(dtype, model_configs, model):
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)
+ model_configs_mla = {
+ # test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
+ "mla_1_0": ModelConfig(
+ 8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias", head_dim_v=128
+ ), # self , 0
+ "mla_1_1": ModelConfig(
+ 4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128
+ ), # cross, 0
+ "mla_2_0": ModelConfig(
+ 2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64
+ ), # self , 1
+ "mla_2_1": ModelConfig(
+ 1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64
+ ), # cross, 1
+ "mla_3_0": ModelConfig(
+ 8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=64
+ ), # inference
+ "mla_3_1": ModelConfig(
+ 8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128
+ ), # inference
+ }
+ @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
+ @pytest.mark.parametrize("dtype", param_types)
+ @pytest.mark.parametrize("model_configs", [model_configs_mla])
+ @pytest.mark.parametrize("model", model_configs_mla.keys())
+ def test_dpa_mla(dtype, model_configs, model):
+ """Test DotProductAttention module with Multi-Latent Attention (MLA)"""
+ test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)
model_configs_mask = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
@@ -586,11 +629,13 @@ model_configs_layout_thd = {
@pytest.mark.parametrize("qkv_layout", qkv_layouts_thd)
def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with different QKV layouts"""
- pad_between_seqs = False
+ pad_between_seqs = True
test_dot_product_attention(
dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
)
- pad_between_seqs = True
+ if get_cudnn_version() >= (9, 3, 0):
+ # cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run
+ pad_between_seqs = False
test_dot_product_attention(
dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
)
@@ -736,7 +781,8 @@ def _run_dot_product_attention(
"skv": config.max_seqlen_kv,
"h": config.num_heads,
"hg": config.num_gqa_groups,
- "d": config.head_dim,
+ "dqk": config.head_dim_qk,
+ "dv": config.head_dim_v,
"t": cu_seqlens_q_after_pad[-1],
"tg": cu_seqlens_kv_after_pad[-1],
"3": 3,
@@ -753,12 +799,16 @@ def _run_dot_product_attention(
layout = layout.replace("s", "skv")
layout = layout.replace("h", "hg")
layout = layout.replace("t", "tg")
+ if i == 2:
+ layout = layout.replace("d", "dv")
+ else:
+ layout = layout.replace("d", "dqk")
tensor_shape = [dim_to_num[j] for j in layout.split("_")]
tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
tensor_orig = tensor
if qkv_format == "thd" and pad_between_seqs:
tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
- if layout in ["t_h_d", "t_3_h_d", "t_h_3_d"]:
+ if layout in ["t_h_dqk", "t_3_h_dqk", "t_h_3_dqk"]:
for i in range(1, config.batch_size + 1):
valid_range = (
cu_seqlens_q_after_pad[i - 1],
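To make the new "dqk"/"dv" mapping concrete, here is roughly what the layout-to-shape expansion above produces for an MLA case (a standalone sketch mirroring the logic in this hunk, with hypothetical sizes):

    # Shape expansion for qkv_layout = "bshd_bshd_bshd"; only V (i == 2) picks up "dv".
    dim_to_num = {"b": 2, "sq": 128, "skv": 256, "h": 16, "hg": 16, "dqk": 192, "dv": 128}
    shapes = []
    for i, fmt in enumerate("bshd_bshd_bshd".split("_")):
        layout = "_".join(fmt)
        if i > 0:
            layout = layout.replace("s", "skv").replace("h", "hg")
        else:
            layout = layout.replace("s", "sq")
        layout = layout.replace("d", "dv" if i == 2 else "dqk")
        shapes.append([dim_to_num[j] for j in layout.split("_")])
    print(shapes)  # [[2, 128, 16, 192], [2, 256, 16, 192], [2, 256, 16, 128]]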
@@ -772,7 +822,7 @@ def _run_dot_product_attention(
tensor_orig = torch.cat(
[tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
)
- if layout in ["tg_hg_d", "tg_2_hg_d", "tg_hg_2_d"]:
+ if layout in ["tg_hg_dqk", "tg_2_hg_dqk", "tg_hg_2_dqk", "tg_hg_dv"]:
for i in range(1, config.batch_size + 1):
valid_range = (
cu_seqlens_kv_after_pad[i - 1],
@@ -811,13 +861,14 @@ def _run_dot_product_attention(
# Create output gradient
qkv_format_kv = "_".join(qkv_format)
qkv_format_kv = qkv_format_kv.replace("s", "sq")
+ qkv_format_kv = qkv_format_kv.replace("d", "dv")
out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda")
out_grad_orig = out_grad
if qkv_format == "thd" and pad_between_seqs:
out_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
- if qkv_format_kv == "t_h_d":
+ if qkv_format_kv == "t_h_dv":
for i in range(1, config.batch_size + 1):
valid_range = (
cu_seqlens_q_after_pad[i - 1],
@@ -851,7 +902,7 @@ def _run_dot_product_attention(
# Set up model
block = DotProductAttention(
config.num_heads,
- config.head_dim,
+ config.head_dim_qk,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
@@ -906,6 +957,7 @@ def _run_dot_product_attention(
if backend == "FusedAttention":
if qkv_format == "thd" and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
+ if is_training:
q_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
k_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
v_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
@@ -919,6 +971,7 @@ def _run_dot_product_attention(
cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
)
out_orig = torch.cat([out_orig, out[valid_range_q[0] : valid_range_q[1]]], dim=0)
+ if is_training:
q_grad_orig = torch.cat(
[q_grad_orig, q.grad[valid_range_q[0] : valid_range_q[1]]], dim=0
)
@@ -1168,7 +1221,7 @@ def _run_transformer_layer(
# Create RoPE
rotary_pos_emb = None
if RoPE:
- PE = RotaryPositionEmbedding(dim=config.head_dim)
+ PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")
# Set up model
@@ -1183,7 +1236,7 @@ def _run_transformer_layer(
init_method=init_method,
output_layer_init_method=output_layer_init_method,
layer_number=layer_number,
- kv_channels=config.head_dim,
+ kv_channels=config.head_dim_qk,
self_attn_mask_type=config.attn_mask_type,
tp_group=None,
tp_size=1,
@@ -1356,7 +1409,7 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
mha = MultiheadAttention(
hidden_size=config.hidden_size,
num_attention_heads=config.num_heads,
- kv_channels=config.head_dim,
+ kv_channels=config.head_dim_qk,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
layer_number=1,
@@ -1387,7 +1440,7 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
"skv": config.max_seqlen_kv,
"h": config.num_heads,
"hg": config.num_gqa_groups,
- "d": config.head_dim,
+ "d": config.head_dim_qk,
"t": cu_seqlens_q[-1],
"tg": cu_seqlens_kv[-1],
"3": 3,
@@ -1531,7 +1584,7 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
with fp8_model_init(enabled=fp8_dpa):
dpa = DotProductAttention(
config.num_heads,
- config.head_dim,
+ config.head_dim_qk,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
sequence_parallel=False,
@@ -1560,7 +1613,7 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
"skv": config.max_seqlen_kv,
"h": config.num_heads,
"hg": config.num_gqa_groups,
- "d": config.head_dim,
+ "d": config.head_dim_qk,
"t": cu_seqlens_q[-1],
"tg": cu_seqlens_kv[-1],
"3": 3,
@@ -1732,7 +1785,7 @@ def _run_custom_mha_fp8(dtype, config, backend):
inp = 0.0001 * torch.randint(
-100,
100,
- (config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim),
+ (config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim_qk),
dtype=dtype,
device="cuda",
requires_grad=True,
@@ -1743,7 +1796,7 @@ def _run_custom_mha_fp8(dtype, config, backend):
out_grad = 0.01 * torch.randn(
config.batch_size * config.max_seqlen_q,
- config.num_heads * config.head_dim,
+ config.num_heads * config.head_dim_qk,
dtype=dtype,
device="cuda",
)
@@ -1766,7 +1819,7 @@ def _run_custom_mha_fp8(dtype, config, backend):
return (
out.view(config.batch_size, config.max_seqlen_q, -1),
dqkv.view(
- config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim
+ config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim_qk
).contiguous(),
)
@@ -1809,7 +1862,7 @@ def _run_ref_mha_f16(dtype, config, backend):
block = DotProductAttention(
config.num_heads,
- config.head_dim,
+ config.head_dim_qk,
attention_dropout=config.dropout_p,
sequence_parallel=False,
tp_size=1,
@@ -2105,7 +2158,7 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
self.p_dropout = config.dropout_p
self.h = config.num_heads
self.hidden_size = config.hidden_size
- self.head_dim = config.head_dim
+ self.head_dim = config.head_dim_qk
self.fast_zero_fill = True
self.mask_type = config.attn_mask_type
...
@@ -1083,7 +1083,7 @@ def test_export_core_attention(
model = te.attention.DotProductAttention(
num_attention_heads=num_attention_heads,
- kv_channels=kv_channels,
+ k_channels=kv_channels,
attention_dropout=0.5,
qkv_format=qkv_format,
attn_mask_type=attn_mask_type,
...
@@ -72,8 +72,8 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) {
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
- size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left,
- int64_t window_size_right) {
+ size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
+ int64_t window_size_left, int64_t window_size_right) {
using namespace transformer_engine;
NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
const int device_id = cuda::current_device();
@@ -84,10 +84,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
if (((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2)) &&
(sm_arch_ >= 90) && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) &&
(((cudnn_runtime_version >= 8900) && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) &&
- (max_seqlen_q == max_seqlen_kv) && (max_seqlen_q <= 512) && (head_dim == 64) &&
- (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)) ||
+ (max_seqlen_q == max_seqlen_kv) && (max_seqlen_q <= 512) && (head_dim_qk == 64) &&
+ (head_dim_v == 64) && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)) ||
((cudnn_runtime_version >= 90201) && (max_seqlen_q % 128 == 0) &&
- (max_seqlen_kv % 128 == 0) && (head_dim == 128) &&
+ (max_seqlen_kv % 128 == 0) && (head_dim_qk == 128) && (head_dim_v == 128) &&
((qkv_format == NVTE_QKV_Format::NVTE_BSHD) ||
(qkv_format == NVTE_QKV_Format::NVTE_SBHD)) &&
((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
@@ -104,8 +104,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
bool flag_m512 = false;
bool flag_arb = false;
if ((sm_arch_ == 80 || sm_arch_ == 90) && (max_seqlen_q <= 512 && max_seqlen_q % 64 == 0) &&
- (max_seqlen_kv <= 512 && max_seqlen_kv % 64 == 0) && (head_dim == 64) &&
- (num_attn_heads == num_gqa_groups) &&
+ (max_seqlen_kv <= 512 && max_seqlen_kv % 64 == 0) && (head_dim_qk == 64) &&
+ (head_dim_v == 64) && (num_attn_heads == num_gqa_groups) &&
((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
(bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) &&
((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
@@ -131,11 +131,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) ||
(cudnn_runtime_version >= 8907)) &&
// head dimension
- ((head_dim <= 128 && head_dim % 8 == 0) ||
+ ((head_dim_qk <= 128 && head_dim_qk % 8 == 0 && head_dim_v <= 128 && head_dim_v % 8 == 0) ||
// TODO (cyang): add is_training to nvte_get_fused_attn_backend
// d=256 only supported for forward
- (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim <= 256 &&
- head_dim % 8 == 0)) &&
+ (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim_qk <= 256 &&
+ head_dim_qk % 8 == 0 && head_dim_v <= 256 && head_dim_v % 8 == 0)) &&
// bias type
((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
((cudnn_runtime_version >= 8906) &&
@@ -155,6 +155,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) ||
((cudnn_runtime_version >= 90300) &&
attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
+ max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 &&
bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
(qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) &&
max_seqlen_q <= max_seqlen_kv && dropout == 0.0)) &&
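As read from the conditions above: the FP8 and max-512 paths now require head_dim_qk and head_dim_v to satisfy the same bounds the single head_dim used to, the arbitrary-seqlen path checks both dims (each at most 128 and a multiple of 8, or up to 256 on sm90 with cuDNN 9.0+, forward only), and the bottom-right causal mask path additionally requires max_seqlen_q and max_seqlen_kv to be multiples of 64.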
@@ -259,7 +260,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
- max_seqlen, d, window_size_left, window_size_right);
+ max_seqlen, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
@@ -336,7 +337,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
- max_seqlen, d, window_size_left, window_size_right);
+ max_seqlen, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
@@ -430,7 +431,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
- max_seqlen_kv, d, window_size_left, window_size_right);
+ max_seqlen_kv, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
@@ -514,7 +515,7 @@ void nvte_fused_attn_bwd_kvpacked(
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
- max_seqlen_kv, d, window_size_left, window_size_right);
+ max_seqlen_kv, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
@@ -595,7 +596,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
size_t h_q = input_Q->data.shape[ndim - 2];
size_t h_kv = input_K->data.shape[ndim - 2];
- size_t d = input_Q->data.shape[ndim - 1];
+ size_t d_qk = input_Q->data.shape[ndim - 1];
+ size_t d_v = input_V->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
@@ -603,13 +605,13 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
- max_seqlen_kv, d, window_size_left, window_size_right);
+ max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
- fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout,
- qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V,
- input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
+ fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale,
+ dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K,
+ input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
@@ -617,18 +619,18 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd(
- b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout,
- bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_K, input_V,
- input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
- input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
- handle);
+ b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, attn_scale, dropout,
+ qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q,
+ input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
+ input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state,
+ wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
- fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale,
+ fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale,
dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V,
input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle);
@@ -674,7 +676,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
size_t h_q = input_Q->data.shape[ndim - 2];
size_t h_kv = input_K->data.shape[ndim - 2];
- size_t d = input_Q->data.shape[ndim - 1];
+ size_t d_qk = input_Q->data.shape[ndim - 1];
+ size_t d_v = input_V->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
@@ -682,15 +685,15 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
- max_seqlen_kv, d, window_size_left, window_size_right);
+ max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
- fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout,
- bias_type, attn_mask_type, input_Q, input_K, input_V, input_dO, output_S,
- output_dQ, output_dK, output_dV, output_dBias, input_cu_seqlens_q,
- input_cu_seqlens_kv, wkspace, stream, handle);
+ fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout,
+ qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V,
+ input_dO, output_S, output_dQ, output_dK, output_dV, output_dBias,
+ input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
@@ -705,9 +708,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
}
fused_attn_arbitrary_seqlen_bwd(
- b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type,
- attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, input_K,
- input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, output_dV,
+ b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, qkv_layout,
+ bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q,
+ input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, output_dV,
output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else
@@ -721,7 +724,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const Tensor *input_M = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[0]);
const Tensor *input_ZInv = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[1]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[2]);
- fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout,
+ fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O,
input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ,
output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv,
...
@@ -48,11 +48,11 @@
namespace transformer_engine {
namespace fused_attn {
void fused_attn_arbitrary_seqlen_fwd_impl(
- int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, int64_t bias_b,
- int64_t bias_h, bool is_training, float scaling_factor, float dropout_probability,
- NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
- int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK,
- void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO,
+ int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
+ int64_t bias_b, int64_t bias_h, bool is_training, float scaling_factor,
+ float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type,
+ NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, void *devPtrQ,
+ void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO,
void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV,
cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
@@ -86,7 +86,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
hg,
s_q,
s_kv,
- d,
+ d_qk,
+ d_v,
bias_b,
bias_h,
scaling_factor,
@@ -167,41 +168,41 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std::vector<int64_t> q_stride(4);
std::vector<int64_t> k_stride(4);
std::vector<int64_t> v_stride(4);
- generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout,
+ generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_Q_Matrix);
- generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout,
+ generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_K_Matrix);
- generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout,
+ generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_V_Matrix);
if (is_ragged) {
Q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
- .set_dim({b, h, s_q, d})
+ .set_dim({b, h, s_q, d_qk})
.set_stride(q_stride)
.set_ragged_offset(offset_q));
K = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
- .set_dim({b, hg, s_kv, d})
+ .set_dim({b, hg, s_kv, d_qk})
.set_stride(k_stride)
.set_ragged_offset(offset_k));
V = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
- .set_dim({b, hg, s_kv, d})
+ .set_dim({b, hg, s_kv, d_v})
.set_stride(v_stride)
.set_ragged_offset(offset_v));
} else {
Q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
- .set_dim({b, h, s_q, d})
+ .set_dim({b, h, s_q, d_qk})
.set_stride(q_stride));
K = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
- .set_dim({b, hg, s_kv, d})
+ .set_dim({b, hg, s_kv, d_qk})
.set_stride(k_stride));
V = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
- .set_dim({b, hg, s_kv, d})
+ .set_dim({b, hg, s_kv, d_v})
.set_stride(v_stride));
}
@@ -265,15 +266,15 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
auto [O, Stats] = mha_graph->sdpa(Q, K, V, sdpa_options);
std::vector<int64_t> o_stride(4);
- generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout,
+ generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_O_Matrix);
if (is_ragged) {
O->set_output(true)
- .set_dim({b, h, s_q, d})
+ .set_dim({b, h, s_q, d_v})
.set_stride(o_stride)
.set_ragged_offset(offset_o);
} else {
- O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride);
+ O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride);
}
Stats->set_output(true)
@@ -360,7 +361,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
void *devOffsetsO = static_cast<int8_t *>(devOffsetsV) + (b + 1) * sizeof(int32_t);
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
cu_seqlens_padded_to_offsets<<<grid, nthreads_per_block, 0, stream>>>(
- layout_group, b, h, hg, d, static_cast<int32_t *>(devPtrSeqOffsetsQ),
+ layout_group, b, h, hg, d_qk, d_v, static_cast<int32_t *>(devPtrSeqOffsetsQ),
static_cast<int32_t *>(devPtrSeqOffsetsKV), static_cast<int32_t *>(devOffsetsQ),
static_cast<int32_t *>(devOffsetsK), static_cast<int32_t *>(devOffsetsV),
static_cast<int32_t *>(devOffsetsO));
@@ -381,13 +382,13 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
}
void fused_attn_arbitrary_seqlen_bwd_impl(
- int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, int64_t bias_b,
- int64_t bias_h, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
- NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
- int64_t window_size_right, bool deterministic, void *devPtrQ, void *devPtrKTranspose,
- void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias,
- void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias,
- void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
+ int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
+ int64_t bias_b, int64_t bias_h, float scaling_factor, float dropout_probability,
+ NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+ int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ,
+ void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats,
+ void *devPtrBias, void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO,
+ void *devPtrdBias, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV,
cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle) {
@@ -419,7 +420,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
hg,
s_q,
s_kv,
- d,
+ d_qk,
+ d_v,
bias_b,
bias_h,
scaling_factor,
@@ -505,61 +507,61 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
std::vector<int64_t> k_stride(4);
std::vector<int64_t> v_stride(4);
std::vector<int64_t> o_stride(4);
- generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout,
+ generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_Q_Matrix);
- generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout,
+ generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_K_Matrix);
- generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout,
+ generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_V_Matrix);
- generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout,
+ generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_O_Matrix);
if (is_ragged) {
q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
- .set_dim({b, h, s_q, d})
+ .set_dim({b, h, s_q, d_qk})
.set_stride(q_stride)
.set_ragged_offset(offset_q));
k = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
- .set_dim({b, hg, s_kv, d})
+ .set_dim({b, hg, s_kv, d_qk})
.set_stride(k_stride)
.set_ragged_offset(offset_k));
v = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
- .set_dim({b, hg, s_kv, d})
+ .set_dim({b, hg, s_kv, d_v})
.set_stride(v_stride)
.set_ragged_offset(offset_v));
o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("O")
- .set_dim({b, h, s_q, d})
+ .set_dim({b, h, s_q, d_v})
.set_stride(o_stride)
.set_ragged_offset(offset_o));
dO = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("dO")
- .set_dim({b, h, s_q, d})
+ .set_dim({b, h, s_q, d_v})
.set_stride(o_stride)
.set_ragged_offset(offset_o));
} else {
q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
- .set_dim({b, h, s_q, d})
+ .set_dim({b, h, s_q, d_qk})
.set_stride(q_stride));
k = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
- .set_dim({b, hg, s_kv, d})
+ .set_dim({b, hg, s_kv, d_qk})
.set_stride(k_stride));
v = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
- .set_dim({b, hg, s_kv, d})
+ .set_dim({b, hg, s_kv, d_v})
.set_stride(v_stride));
o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("O")
- .set_dim({b, h, s_q, d})
+ .set_dim({b, h, s_q, d_v})
.set_stride(o_stride));
dO = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("dO")
- .set_dim({b, h, s_q, d})
+ .set_dim({b, h, s_q, d_v})
.set_stride(o_stride));
}
stats = mha_graph->tensor(fe::graph::Tensor_attributes()
@@ -644,21 +646,21 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
if (is_ragged) {
dQ->set_output(true)
- .set_dim({b, h, s_q, d})
+ .set_dim({b, h, s_q, d_qk})
.set_stride(q_stride)
.set_ragged_offset(offset_q);
dK->set_output(true)
- .set_dim({b, hg, s_kv, d})
+ .set_dim({b, hg, s_kv, d_qk})
.set_stride(k_stride)
.set_ragged_offset(offset_k);
dV->set_output(true)
- .set_dim({b, hg, s_kv, d})
+ .set_dim({b, hg, s_kv, d_v})
.set_stride(v_stride)
.set_ragged_offset(offset_v);
} else {
- dQ->set_output(true).set_dim({b, h, s_q, d}).set_stride(q_stride);
- dK->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(k_stride);
- dV->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(v_stride);
+ dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride);
+ dK->set_output(true).set_dim({b, hg, s_kv, d_qk}).set_stride(k_stride);
+ dV->set_output(true).set_dim({b, hg, s_kv, d_v}).set_stride(v_stride);
}
std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>, // q
@@ -758,7 +760,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
void *devOffsetsO = static_cast<int8_t *>(devOffsetsV) + (b + 1) * sizeof(int32_t);
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
cu_seqlens_padded_to_offsets<<<grid, nthreads_per_block, 0, stream>>>(
- layout_group, b, h, hg, d, static_cast<int32_t *>(devPtrSeqOffsetsQ),
+ layout_group, b, h, hg, d_qk, d_v, static_cast<int32_t *>(devPtrSeqOffsetsQ),
static_cast<int32_t *>(devPtrSeqOffsetsKV), static_cast<int32_t *>(devOffsetsQ),
static_cast<int32_t *>(devOffsetsK), static_cast<int32_t *>(devOffsetsV),
static_cast<int32_t *>(devOffsetsO));
@@ -865,11 +867,12 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_fwd_impl(
- batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, bias_b, bias_h,
- is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left,
- window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
- devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets,
- get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
+ batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, bias_b,
+ bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
+ window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
+ devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets,
+ devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream,
+ handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
@@ -941,11 +944,11 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_bwd_impl(
- batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, bias_b, bias_h,
- attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
- deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ,
- devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset,
- devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets,
+ batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, bias_b,
+ bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left,
+ window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
+ devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed,
+ devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
@@ -1051,12 +1054,12 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_fwd_impl(
- batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h,
- is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left,
- window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
- devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ,
- devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size,
- stream, handle);
+ batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim,
+ bias_b, bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
+ window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
+ devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
+ devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
+ &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
@@ -1131,12 +1134,13 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_bwd_impl( fused_attn_arbitrary_seqlen_bwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h, batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bias_b, bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left,
deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed,
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size,
stream, handle);
if (workspace_size > 0) { if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) { if (workspace->data.dptr == nullptr) {
...@@ -1155,8 +1159,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( ...@@ -1155,8 +1159,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
void fused_attn_arbitrary_seqlen_fwd( void fused_attn_arbitrary_seqlen_fwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
...@@ -1233,12 +1237,12 @@ void fused_attn_arbitrary_seqlen_fwd( ...@@ -1233,12 +1237,12 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t workspace_size = 0; size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_fwd_impl( fused_attn_arbitrary_seqlen_fwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h, batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, bias_b, bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
stream, handle); &workspace_size, stream, handle);
if (workspace_size > 0) { if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) { if (workspace->data.dptr == nullptr) {
...@@ -1257,7 +1261,7 @@ void fused_attn_arbitrary_seqlen_fwd( ...@@ -1257,7 +1261,7 @@ void fused_attn_arbitrary_seqlen_fwd(
void fused_attn_arbitrary_seqlen_bwd( void fused_attn_arbitrary_seqlen_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
...@@ -1302,12 +1306,13 @@ void fused_attn_arbitrary_seqlen_bwd( ...@@ -1302,12 +1306,13 @@ void fused_attn_arbitrary_seqlen_bwd(
size_t workspace_size = 0; size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_bwd_impl( fused_attn_arbitrary_seqlen_bwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h, batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bias_b, bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left,
deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed,
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size,
stream, handle);
if (workspace_size > 0) { if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) { if (workspace->data.dptr == nullptr) {
......
...@@ -58,8 +58,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( ...@@ -58,8 +58,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
void fused_attn_arbitrary_seqlen_fwd( void fused_attn_arbitrary_seqlen_fwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
...@@ -68,7 +68,7 @@ void fused_attn_arbitrary_seqlen_fwd( ...@@ -68,7 +68,7 @@ void fused_attn_arbitrary_seqlen_fwd(
void fused_attn_arbitrary_seqlen_bwd( void fused_attn_arbitrary_seqlen_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
......
...@@ -1679,6 +1679,7 @@ void fused_attn_fp8_fwd_impl_v1( ...@@ -1679,6 +1679,7 @@ void fused_attn_fp8_fwd_impl_v1(
s_q, s_q,
s_kv, s_kv,
d, d,
d,
bias_b, bias_b,
bias_h, bias_h,
scaling_factor, scaling_factor,
...@@ -1976,6 +1977,7 @@ void fused_attn_fp8_bwd_impl_v1( ...@@ -1976,6 +1977,7 @@ void fused_attn_fp8_bwd_impl_v1(
s_q, s_q,
s_kv, s_kv,
d, d,
d,
bias_b, bias_b,
bias_h, bias_h,
scaling_factor, scaling_factor,
......
...@@ -363,29 +363,30 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu ...@@ -363,29 +363,30 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu
// convert cu_seqlens_padded to offsets // convert cu_seqlens_padded to offsets
__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h, __global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h,
size_t hg, size_t d, int32_t *cu_seqlens_q_padded, size_t hg, size_t d_qk, size_t d_v,
int32_t *cu_seqlens_q_padded,
int32_t *cu_seqlens_kv_padded, int32_t *offsets_q, int32_t *cu_seqlens_kv_padded, int32_t *offsets_q,
int32_t *offsets_k, int32_t *offsets_v, int32_t *offsets_k, int32_t *offsets_v,
int32_t *offsets_o) { int32_t *offsets_o) {
size_t tid = blockIdx.x * blockDim.x + threadIdx.x; size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < b + 1) { if (tid < b + 1) {
offsets_o[tid] = h * d * cu_seqlens_q_padded[tid]; offsets_o[tid] = h * d_v * cu_seqlens_q_padded[tid];
switch (layout_group) { switch (layout_group) {
case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
offsets_q[tid] = h * d * cu_seqlens_q_padded[tid]; offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[tid];
offsets_k[tid] = hg * d * cu_seqlens_kv_padded[tid]; offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[tid];
offsets_v[tid] = offsets_k[tid]; offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[tid];
break; break;
case NVTE_QKV_Layout_Group::NVTE_3HD: case NVTE_QKV_Layout_Group::NVTE_3HD:
case NVTE_QKV_Layout_Group::NVTE_H3D: case NVTE_QKV_Layout_Group::NVTE_H3D:
offsets_q[tid] = 3 * h * d * cu_seqlens_q_padded[tid]; offsets_q[tid] = 3 * h * d_qk * cu_seqlens_q_padded[tid];
offsets_k[tid] = offsets_q[tid]; offsets_k[tid] = offsets_q[tid];
offsets_v[tid] = offsets_q[tid]; offsets_v[tid] = offsets_q[tid];
break; break;
case NVTE_QKV_Layout_Group::NVTE_HD_2HD: case NVTE_QKV_Layout_Group::NVTE_HD_2HD:
case NVTE_QKV_Layout_Group::NVTE_HD_H2D: case NVTE_QKV_Layout_Group::NVTE_HD_H2D:
offsets_q[tid] = h * d * cu_seqlens_q_padded[tid]; offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[tid];
offsets_k[tid] = 2 * hg * d * cu_seqlens_kv_padded[tid]; offsets_k[tid] = 2 * hg * d_qk * cu_seqlens_kv_padded[tid];
offsets_v[tid] = offsets_k[tid]; offsets_v[tid] = offsets_k[tid];
break; break;
} }
......
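The kernel above now computes separate K and V offsets when the head dimensions differ; a rough Python restatement of the HD_HD_HD branch (sizes are illustrative, and only the index arithmetic is mirrored, not the CUDA launch):

```python
# Sketch of cu_seqlens_padded_to_offsets for the HD_HD_HD layout group.
# cu_seqlens_*_padded are cumulative (padded) token counts per batch entry.
def offsets_hd_hd_hd(h, hg, d_qk, d_v, cu_seqlens_q_padded, cu_seqlens_kv_padded):
    offsets_q = [h * d_qk * t for t in cu_seqlens_q_padded]
    offsets_k = [hg * d_qk * t for t in cu_seqlens_kv_padded]
    offsets_v = [hg * d_v * t for t in cu_seqlens_kv_padded]  # no longer equal to offsets_k
    offsets_o = [h * d_v * t for t in cu_seqlens_q_padded]    # O uses the V head dim
    return offsets_q, offsets_k, offsets_v, offsets_o

# Hypothetical example: two sequences, each padded to 128 tokens.
print(offsets_hd_hd_hd(h=16, hg=16, d_qk=192, d_v=128,
                       cu_seqlens_q_padded=[0, 128, 256],
                       cu_seqlens_kv_padded=[0, 128, 256]))
```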
...@@ -91,7 +91,8 @@ struct FADescriptor_v1 { ...@@ -91,7 +91,8 @@ struct FADescriptor_v1 {
std::int64_t hg; std::int64_t hg;
std::int64_t s_q; std::int64_t s_q;
std::int64_t s_kv; std::int64_t s_kv;
std::int64_t d; std::int64_t d_qk;
std::int64_t d_v;
std::int64_t bias_b; std::int64_t bias_b;
std::int64_t bias_h; std::int64_t bias_h;
float attnScale; float attnScale;
...@@ -107,11 +108,11 @@ struct FADescriptor_v1 { ...@@ -107,11 +108,11 @@ struct FADescriptor_v1 {
cudnn_frontend::DataType_t bwd_tensor_type; cudnn_frontend::DataType_t bwd_tensor_type;
bool operator<(const FADescriptor_v1 &rhs) const { bool operator<(const FADescriptor_v1 &rhs) const {
return std::tie(b, h, hg, s_q, s_kv, d, bias_b, bias_h, attnScale, isTraining, return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, bias_b, bias_h, attnScale, isTraining,
dropoutProbability, layout, mask_type, window_size_left, window_size_right, dropoutProbability, layout, mask_type, window_size_left, window_size_right,
deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) < deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) <
std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d, rhs.bias_b, rhs.bias_h, std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.bias_b,
rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout,
rhs.mask_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.mask_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic,
rhs.bias_type, rhs.fwd_tensor_type, rhs.bwd_tensor_type); rhs.bias_type, rhs.fwd_tensor_type, rhs.bwd_tensor_type);
} }
...@@ -126,7 +127,8 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu ...@@ -126,7 +127,8 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu
int32_t *kv_seqlens); int32_t *kv_seqlens);
__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h, __global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h,
size_t hg, size_t d, int32_t *cu_seqlens_q_padded, size_t hg, size_t d_qk, size_t d_v,
int32_t *cu_seqlens_q_padded,
int32_t *cu_seqlens_kv_padded, int32_t *offsets_q, int32_t *cu_seqlens_kv_padded, int32_t *offsets_q,
int32_t *offsets_k, int32_t *offsets_v, int32_t *offsets_k, int32_t *offsets_v,
int32_t *offsets_o); int32_t *offsets_o);
......
...@@ -147,15 +147,16 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout); ...@@ -147,15 +147,16 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout);
* \param[in] num_gqa_groups The number of heads in K, V. * \param[in] num_gqa_groups The number of heads in K, V.
* \param[in] max_seqlen_q The sequence length of Q. * \param[in] max_seqlen_q The sequence length of Q.
* \param[in] max_seqlen_kv The sequence length of K, V. * \param[in] max_seqlen_kv The sequence length of K, V.
* \param[in] head_dim The head dimension of Q, K, V. * \param[in] head_dim_qk The head dimension of Q, K.
* \param[in] head_dim_v The head dimension of V.
* \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half). * \param[in] window_size_right Sliding window size (the right half).
*/ */
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
int64_t window_size_right); int64_t window_size_left, int64_t window_size_right);
/*! \brief Compute dot product attention with packed QKV input. /*! \brief Compute dot product attention with packed QKV input.
* *
......
...@@ -19,7 +19,7 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, ...@@ -19,7 +19,7 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen,
head_dim, -1, -1); head_dim, head_dim, -1, -1);
return backend; return backend;
} }
...@@ -255,10 +255,10 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -255,10 +255,10 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
/* Prepare RNG state */ /* Prepare RNG state */
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64); auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = auto backend = nvte_get_fused_attn_backend(
nvte_get_fused_attn_backend(static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
qkv_layout, bias_type, mask_type, dropout_probability, attn_heads, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim, -1, -1); head_dim, head_dim, -1, -1);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
/* Auxiliary tensors (to be propagated to the backward pass later) */ /* Auxiliary tensors (to be propagated to the backward pass later) */
...@@ -486,10 +486,10 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -486,10 +486,10 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
/* Auxiliary tensors (propagated from the forward pass) */ /* Auxiliary tensors (propagated from the forward pass) */
NVTETensorPack aux_input_tensors; NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors); nvte_tensor_pack_create(&aux_input_tensors);
auto backend = auto backend = nvte_get_fused_attn_backend(
nvte_get_fused_attn_backend(static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
qkv_layout, bias_type, mask_type, dropout_probability, attn_heads, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim, -1, -1); head_dim, head_dim, -1, -1);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux, PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
rng_state, bias); rng_state, bias);
......
...@@ -131,10 +131,10 @@ inline NVTE_Fused_Attn_Backend get_fused_attn_backend( ...@@ -131,10 +131,10 @@ inline NVTE_Fused_Attn_Backend get_fused_attn_backend(
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim) { size_t max_seqlen_kv, size_t head_dim) {
NVTE_Fused_Attn_Backend fused_attention_backend = NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
nvte_get_fused_attn_backend(static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
qkv_layout, bias_type, attn_mask_type, p_dropout, num_attn_heads, attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv,
num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, -1, -1); head_dim, head_dim, -1, -1);
return fused_attention_backend; return fused_attention_backend;
} }
......
...@@ -142,8 +142,10 @@ class AttentionParams: ...@@ -142,8 +142,10 @@ class AttentionParams:
Maximum sequence length of the query tensor. Maximum sequence length of the query tensor.
max_seqlen_kv: int, default = 128 max_seqlen_kv: int, default = 128
Maximum sequence length of the key and value tensors. Maximum sequence length of the key and value tensors.
head_dim: int, default = 64 head_dim_qk: int, default = 64
The size of each attention head. The size of each attention head in query and key tensors.
head_dim_v: int, default = 64
The size of each attention head in the value tensor.
attn_mask_type: str, default = `no_mask` attn_mask_type: str, default = `no_mask`
Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`, Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`,
`causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`} `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
...@@ -182,7 +184,8 @@ class AttentionParams: ...@@ -182,7 +184,8 @@ class AttentionParams:
num_gqa_groups: int = 16 num_gqa_groups: int = 16
max_seqlen_q: int = 128 max_seqlen_q: int = 128
max_seqlen_kv: int = 128 max_seqlen_kv: int = 128
head_dim: int = 64 head_dim_qk: int = 64
head_dim_v: int = 64
attn_mask_type: str = "no_mask" attn_mask_type: str = "no_mask"
window_size: Union[Tuple[int, int], None] = None window_size: Union[Tuple[int, int], None] = None
alibi_slopes_shape: Union[torch.Size, List, None] = None alibi_slopes_shape: Union[torch.Size, List, None] = None
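A hedged construction example for the extended dataclass (values are illustrative; fields not shown in this hunk are assumed to keep their defaults):

```python
# Hypothetical MLA configuration: query/key heads are wider than value heads.
params = AttentionParams(
    num_gqa_groups=16,
    max_seqlen_q=512,
    max_seqlen_kv=512,
    head_dim_qk=192,
    head_dim_v=128,
    attn_mask_type="causal",
)
```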
...@@ -245,7 +248,8 @@ def get_attention_backend( ...@@ -245,7 +248,8 @@ def get_attention_backend(
num_gqa_groups = attention_params.num_gqa_groups num_gqa_groups = attention_params.num_gqa_groups
max_seqlen_q = attention_params.max_seqlen_q max_seqlen_q = attention_params.max_seqlen_q
max_seqlen_kv = attention_params.max_seqlen_kv max_seqlen_kv = attention_params.max_seqlen_kv
head_dim = attention_params.head_dim head_dim_qk = attention_params.head_dim_qk
head_dim_v = attention_params.head_dim_v
attn_mask_type = attention_params.attn_mask_type attn_mask_type = attention_params.attn_mask_type
window_size = attention_params.window_size window_size = attention_params.window_size
alibi_slopes_shape = attention_params.alibi_slopes_shape alibi_slopes_shape = attention_params.alibi_slopes_shape
...@@ -352,19 +356,31 @@ def get_attention_backend( ...@@ -352,19 +356,31 @@ def get_attention_backend(
use_unfused_attention = False use_unfused_attention = False
# Filter: Head dimension # Filter: Head dimension
if use_flash_attention and head_dim_qk != head_dim_v:
logger.debug("Disabling FlashAttention as it does not support MLA.")
use_flash_attention = False
if use_flash_attention and ( if use_flash_attention and (
head_dim > 256 head_dim_qk > 256
or head_dim % 8 != 0 or head_dim_qk % 8 != 0
or (head_dim > 192 and device_compute_capability not in ((8, 0), (9, 0))) or (head_dim_qk > 192 and device_compute_capability not in ((8, 0), (9, 0)))
): ):
logger.debug( logger.debug(
"Disabling FlashAttention due to unsupported head_dim. " "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. "
"Supported: head_dim %%8 = 0, head_dim <= 256 (>192 requires sm80/90). " "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
"Found: head_dim = %s on sm%s.", "head_dim_qk <= 256 (>192 requires sm80/90). "
head_dim, "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
head_dim_qk,
head_dim_v,
".".join([str(i) for i in device_compute_capability]), ".".join([str(i) for i in device_compute_capability]),
) )
use_flash_attention = False use_flash_attention = False
qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
if use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd":
logger.debug(
"Disabling FusedAttention as MLA is not supported with qkv_layout = %s",
qkv_layout,
)
use_fused_attention = False
# Filter: QKV layout # Filter: QKV layout
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
...@@ -557,7 +573,8 @@ def get_attention_backend( ...@@ -557,7 +573,8 @@ def get_attention_backend(
num_gqa_groups, num_gqa_groups,
max_seqlen_q, max_seqlen_q,
max_seqlen_kv, max_seqlen_kv,
head_dim, head_dim_qk,
head_dim_v,
window_size[0], window_size[0],
window_size[1], window_size[1],
) )
...@@ -3132,12 +3149,14 @@ def get_qkv_layout( ...@@ -3132,12 +3149,14 @@ def get_qkv_layout(
stride = q.stride() stride = q.stride()
check_strides_qkv = all(stride == x.stride() for x in [q, k, v]) check_strides_qkv = all(stride == x.stride() for x in [q, k, v])
stride = k.stride() stride = k.stride()
check_strides_kv = all(stride == x.stride() for x in [k, v]) check_strides_kv = torch.equal(
torch.Tensor(stride[:-1]) / k.shape[-1], torch.Tensor(v.stride()[:-1]) / v.shape[-1]
)
shape = q.shape shape = q.shape
check_shapes_qkv = all(shape == x.shape for x in [q, k, v]) check_shapes_qkv = all(shape == x.shape for x in [q, k, v])
shape = k.shape shape = k.shape
check_shapes_kv = all(shape == x.shape for x in [k, v]) check_shapes_kv = shape[:-1] == v.shape[:-1]
last_dim_size = q.shape[-1] last_dim_size = q.shape[-1]
check_last_dim_offsets_qkv = all( check_last_dim_offsets_qkv = all(
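A minimal sketch, assuming hypothetical sbhd tensor sizes, of why the relaxed K/V checks above divide the strides by the head dimension and drop the last dimension of the shape: with MLA, K and V are allowed to differ only in their head size.

```python
import torch

# Hypothetical contiguous sbhd tensors with different K and V head dims (MLA).
k = torch.empty(512, 2, 16, 192)
v = torch.empty(512, 2, 16, 128)

# An exact stride/shape comparison would fail here; normalizing the strides by
# the head dim and comparing all but the last dimension keeps MLA inputs valid.
check_strides_kv = torch.equal(
    torch.Tensor(k.stride()[:-1]) / k.shape[-1],
    torch.Tensor(v.stride()[:-1]) / v.shape[-1],
)
check_shapes_kv = k.shape[:-1] == v.shape[:-1]
print(check_strides_kv, check_shapes_kv)  # True True
```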
...@@ -5177,8 +5196,10 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -5177,8 +5196,10 @@ class DotProductAttention(TransformerEngineBaseModule):
---------- ----------
num_attention_heads : int num_attention_heads : int
number of attention heads in the transformer layer. number of attention heads in the transformer layer.
kv_channels : int k_channels : int
number of key-query-value channels per attention head. number of channels per attention head in key.
v_channels : Optional[int] = None
number of channels per attention head in value.
num_gqa_groups : Optional[int] = None num_gqa_groups : Optional[int] = None
number of GQA groups in the transformer layer. number of GQA groups in the transformer layer.
Grouped Query Attention is described in Grouped Query Attention is described in
...@@ -5264,7 +5285,8 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -5264,7 +5285,8 @@ class DotProductAttention(TransformerEngineBaseModule):
def __init__( def __init__(
self, self,
num_attention_heads: int, num_attention_heads: int,
kv_channels: int, k_channels: int,
v_channels: Optional[int] = None,
num_gqa_groups: Optional[int] = None, num_gqa_groups: Optional[int] = None,
attention_dropout: float = 0.0, attention_dropout: float = 0.0,
qkv_format: str = "sbhd", qkv_format: str = "sbhd",
...@@ -5304,7 +5326,8 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -5304,7 +5326,8 @@ class DotProductAttention(TransformerEngineBaseModule):
self.cp_global_ranks = cp_global_ranks self.cp_global_ranks = cp_global_ranks
self.cp_stream = cp_stream self.cp_stream = cp_stream
self.hidden_size_per_attention_head = kv_channels self.hidden_size_per_attention_head = k_channels
self.v_channels = k_channels if v_channels is None else v_channels
self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size) self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)
...@@ -5322,7 +5345,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -5322,7 +5345,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attention_dropout_ctx = self.rng_states_tracker.fork attention_dropout_ctx = self.rng_states_tracker.fork
if softmax_scale is None: if softmax_scale is None:
softmax_scale = 1.0 / math.sqrt(kv_channels) softmax_scale = 1.0 / math.sqrt(k_channels)
self.deterministic = ( self.deterministic = (
not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
...@@ -5469,16 +5492,6 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -5469,16 +5492,6 @@ class DotProductAttention(TransformerEngineBaseModule):
Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type` Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
includes '"padding"' or `"arbitrary"`. includes '"padding"' or `"arbitrary"`.
.. note::
Input tensor :attr:`query_layer` must be of shape
(:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`,
:attr:`kv_channels`) and the tensors :attr:`key_layer` and :attr:`value_layer`
must each be of shape (:attr:`sequence_length`, :attr:`batch_size`,
:attr:`num_gqa_groups`, :attr:`kv_channels`). Output of shape
(:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`
* :attr:`kv_channels`) is returned.
.. note:: .. note::
DotProductAttention supports three backends: 1) FlashAttention which calls DotProductAttention supports three backends: 1) FlashAttention which calls
...@@ -5628,7 +5641,9 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -5628,7 +5641,9 @@ class DotProductAttention(TransformerEngineBaseModule):
assert ( assert (
query_layer.dtype == key_layer.dtype and query_layer.dtype == value_layer.dtype query_layer.dtype == key_layer.dtype and query_layer.dtype == value_layer.dtype
), "Queries, keys and values must have the same data type!" ), "Queries, keys and values must have the same data type!"
assert key_layer.shape == value_layer.shape, "Keys and values must have the same shape!" assert (
key_layer.shape[:-1] == value_layer.shape[:-1]
), "Keys and values must have the same batch size, sequence length and number of heads!"
if attn_mask_type is None: if attn_mask_type is None:
attn_mask_type = self.attn_mask_type attn_mask_type = self.attn_mask_type
...@@ -5861,7 +5876,8 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -5861,7 +5876,8 @@ class DotProductAttention(TransformerEngineBaseModule):
num_gqa_groups=key_layer.shape[-2], num_gqa_groups=key_layer.shape[-2],
max_seqlen_q=max_seqlen_q, max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv, max_seqlen_kv=max_seqlen_kv,
head_dim=query_layer.shape[-1], head_dim_qk=query_layer.shape[-1],
head_dim_v=value_layer.shape[-1],
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
window_size=window_size, window_size=window_size,
alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None, alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None,
......
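Putting the PyTorch-side changes together, a hedged usage sketch of the new k_channels/v_channels arguments (the import path, dtype and sizes are assumptions; shapes follow the "sbhd" format):

```python
import torch
from transformer_engine.pytorch import DotProductAttention

# Hypothetical MLA setup: 192-dim query/key heads, 128-dim value heads.
attn = DotProductAttention(
    num_attention_heads=16,
    k_channels=192,
    v_channels=128,
    attention_dropout=0.0,
    qkv_format="sbhd",
)

s, b, h = 512, 2, 16
q = torch.randn(s, b, h, 192, dtype=torch.bfloat16, device="cuda")
k = torch.randn(s, b, h, 192, dtype=torch.bfloat16, device="cuda")
v = torch.randn(s, b, h, 128, dtype=torch.bfloat16, device="cuda")
out = attn(q, k, v)
print(out.shape)  # expected: torch.Size([512, 2, 2048]), i.e. h * v_channels last
```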
...@@ -140,7 +140,7 @@ def fused_attn_fwd_qkvpacked( ...@@ -140,7 +140,7 @@ def fused_attn_fwd_qkvpacked(
output tensor, amax of O, used by the next iteration in FP8 computations output tensor, amax of O, used by the next iteration in FP8 computations
attn_scale: float, default = None attn_scale: float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM; if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim) as the default if None, use 1.0/sqrt(head_dim_qk) as the default
dropout: float, default = 0.0 dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output; dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False dropout must be 0.0 if is_training is False
...@@ -342,7 +342,7 @@ def fused_attn_bwd_qkvpacked( ...@@ -342,7 +342,7 @@ def fused_attn_bwd_qkvpacked(
output tensor, amax of dQKV, used by the next iteration in FP8 computations output tensor, amax of dQKV, used by the next iteration in FP8 computations
attn_scale: float, default = None attn_scale: float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM; if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim) as the default if None, use 1.0/sqrt(head_dim_qk) as the default
dropout: float, default = 0.0 dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output; dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False dropout must be 0.0 if is_training is False
...@@ -508,7 +508,7 @@ def fused_attn_fwd_kvpacked( ...@@ -508,7 +508,7 @@ def fused_attn_fwd_kvpacked(
output tensor, amax of O, used by the next iteration in FP8 computations output tensor, amax of O, used by the next iteration in FP8 computations
attn_scale: float, default = None attn_scale: float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM; if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim) as the default if None, use 1.0/sqrt(head_dim_qk) as the default
dropout: float, default = 0.0 dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output; dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False dropout must be 0.0 if is_training is False
...@@ -729,7 +729,7 @@ def fused_attn_bwd_kvpacked( ...@@ -729,7 +729,7 @@ def fused_attn_bwd_kvpacked(
output tensor, amax of dQKV, used by the next iteration in FP8 computations output tensor, amax of dQKV, used by the next iteration in FP8 computations
attn_scale: float, default = None attn_scale: float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM; if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim) as the default if None, use 1.0/sqrt(head_dim_qk) as the default
dropout: float, default = 0.0 dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output; dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False dropout must be 0.0 if is_training is False
...@@ -907,7 +907,7 @@ def fused_attn_fwd( ...@@ -907,7 +907,7 @@ def fused_attn_fwd(
output tensor, amax of O, used by the next iteration in FP8 computations output tensor, amax of O, used by the next iteration in FP8 computations
attn_scale: float, default = None attn_scale: float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM; if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim) as the default if None, use 1.0/sqrt(head_dim_qk) as the default
dropout: float, default = 0.0 dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output; dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False dropout must be 0.0 if is_training is False
...@@ -1135,7 +1135,7 @@ def fused_attn_bwd( ...@@ -1135,7 +1135,7 @@ def fused_attn_bwd(
output tensor, amax of dQ, dK and dV, used by the next iteration in FP8 computations output tensor, amax of dQ, dK and dV, used by the next iteration in FP8 computations
attn_scale: float, default = None attn_scale: float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM; if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim) as the default if None, use 1.0/sqrt(head_dim_qk) as the default
dropout: float, default = 0.0 dropout: float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output; dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False dropout must be 0.0 if is_training is False
......
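The updated docstrings point out that the default attention scale is derived from the QK head dimension only; a one-line illustration (value is hypothetical):

```python
import math

head_dim_qk = 192  # hypothetical QK head dim; head_dim_v does not affect the scale
attn_scale = 1.0 / math.sqrt(head_dim_qk)
print(attn_scale)
```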
...@@ -14,11 +14,14 @@ ...@@ -14,11 +14,14 @@
* Attention * Attention
**************************************************************************************************/ **************************************************************************************************/
NVTE_Fused_Attn_Backend get_fused_attn_backend( NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q_dtype,
const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype, const transformer_engine::DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, NVTE_Mask_Type attn_mask_type, float p_dropout,
size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left, int64_t window_size_right); size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv,
size_t head_dim_qk, size_t head_dim_v,
int64_t window_size_left, int64_t window_size_right);
std::vector<at::Tensor> fused_attn_fwd_qkvpacked( std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero, size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero,
......
...@@ -14,11 +14,12 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( ...@@ -14,11 +14,12 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype, const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left, int64_t window_size_right) { size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right) {
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv,
head_dim, window_size_left, window_size_right); head_dim_qk, head_dim_v, window_size_left, window_size_right);
return fused_attention_backend; return fused_attention_backend;
} }
...@@ -761,7 +762,11 @@ std::vector<at::Tensor> fused_attn_fwd( ...@@ -761,7 +762,11 @@ std::vector<at::Tensor> fused_attn_fwd(
std::vector<size_t> v_shape{v_sizes.begin(), v_sizes.end()}; std::vector<size_t> v_shape{v_sizes.begin(), v_sizes.end()};
// create output tensor O // create output tensor O
auto O = torch::empty_like(Q); auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
auto o_shape = std::vector<size_t>{q_sizes.begin(), q_sizes.end()};
o_shape[o_shape.size() - 1] = v_sizes[v_sizes.size() - 1];
std::vector<int64_t> o_shape_tmp{o_shape.begin(), o_shape.end()};
auto O = torch::empty(c10::IntArrayRef(o_shape_tmp), options);
// construct NVTE tensors // construct NVTE tensors
TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias; TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias;
...@@ -790,7 +795,7 @@ std::vector<at::Tensor> fused_attn_fwd( ...@@ -790,7 +795,7 @@ std::vector<at::Tensor> fused_attn_fwd(
descale_QKV.value().data_ptr()); descale_QKV.value().data_ptr());
te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(), te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(),
scale_S.value().data_ptr(), descale_S.value().data_ptr()); scale_S.value().data_ptr(), descale_S.value().data_ptr());
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(), te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, amax_O.value().data_ptr(),
scale_O.value().data_ptr(), nullptr); scale_O.value().data_ptr(), nullptr);
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
...@@ -801,7 +806,7 @@ std::vector<at::Tensor> fused_attn_fwd( ...@@ -801,7 +806,7 @@ std::vector<at::Tensor> fused_attn_fwd(
te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr); te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr);
te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, nullptr); te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, nullptr);
te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr);
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, nullptr, nullptr, nullptr);
} else { } else {
NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
} }
...@@ -839,8 +844,7 @@ std::vector<at::Tensor> fused_attn_fwd( ...@@ -839,8 +844,7 @@ std::vector<at::Tensor> fused_attn_fwd(
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>( auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
auto options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
auto rng_state = torch::empty({2}, options);
unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
philox_args, static_cast<int64_t *>(rng_state.data_ptr())); philox_args, static_cast<int64_t *>(rng_state.data_ptr()));
auto te_rng_state = makeTransformerEngineTensor(rng_state); auto te_rng_state = makeTransformerEngineTensor(rng_state);
...@@ -935,8 +939,11 @@ std::vector<at::Tensor> fused_attn_bwd( ...@@ -935,8 +939,11 @@ std::vector<at::Tensor> fused_attn_bwd(
std::vector<size_t> v_shape{v_sizes.begin(), v_sizes.end()}; std::vector<size_t> v_shape{v_sizes.begin(), v_sizes.end()};
auto h_q = q_shape[q_shape.size() - 2]; auto h_q = q_shape[q_shape.size() - 2];
auto h_kv = k_shape[k_shape.size() - 2]; auto h_kv = k_shape[k_shape.size() - 2];
auto d = q_shape[q_shape.size() - 1]; auto d_qk = q_shape[q_shape.size() - 1];
auto d_v = v_shape[v_shape.size() - 1];
auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA);
std::vector<size_t> o_shape{q_sizes.begin(), q_sizes.end()};
o_shape[o_shape.size() - 1] = d_v;
at::Tensor dQ; at::Tensor dQ;
at::Tensor dK; at::Tensor dK;
...@@ -1015,7 +1022,7 @@ std::vector<at::Tensor> fused_attn_bwd( ...@@ -1015,7 +1022,7 @@ std::vector<at::Tensor> fused_attn_bwd(
TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV; TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8 // FP8
if (set_zero && ((h_q * d) % block_size == 0) && ((h_kv * d) % block_size == 0) && if (set_zero && ((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) &&
dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous() && dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous() &&
(nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
...@@ -1041,9 +1048,9 @@ std::vector<at::Tensor> fused_attn_bwd( ...@@ -1041,9 +1048,9 @@ std::vector<at::Tensor> fused_attn_bwd(
descale_QKV.value().data_ptr()); descale_QKV.value().data_ptr());
te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr,
descale_QKV.value().data_ptr()); descale_QKV.value().data_ptr());
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, nullptr, nullptr,
descale_O.value().data_ptr()); descale_O.value().data_ptr());
te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, te_dO = makeTransformerEngineTensor(dO.data_ptr(), o_shape, dqkv_type, nullptr, nullptr,
descale_dO.value().data_ptr()); descale_dO.value().data_ptr());
te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr,
scale_S.value().data_ptr(), descale_S.value().data_ptr()); scale_S.value().data_ptr(), descale_S.value().data_ptr());
...@@ -1068,9 +1075,9 @@ std::vector<at::Tensor> fused_attn_bwd( ...@@ -1068,9 +1075,9 @@ std::vector<at::Tensor> fused_attn_bwd(
te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr);
te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr); te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr);
te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, nullptr); te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, nullptr);
te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, nullptr, nullptr, nullptr);
te_dO = te_dO =
makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, nullptr); makeTransformerEngineTensor(dO.data_ptr(), o_shape, dqkv_type, nullptr, nullptr, nullptr);
te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr);
te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr);
te_dQ = te_dQ =
......