Unverified commit 71c76b6b, authored by Charlene Yang, committed by GitHub
Browse files

Add support for head_dim > 128 (#1797)



* add support for head dim > 128
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* remove debugging
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* raise tols slightly to tolerate 1/2048 mismatches
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix is_training for test_te_layer
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add bprop support for blackwell
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor tweak for format
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix backend selection results
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* bump sm100 to sm100+
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add sq=1 test for MLA
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* enable sq=1 for bprop
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* minor tweak in comments
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix head_dim logic and remove pytest skip
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add FE fix for d>128
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* update FE again to take in small fixes
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add cuDNN version info in L0 tests
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* increase tols for Unfused + large dim
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* Revert "add cuDNN version info in L0 tests"

This reverts commit 3e1b426ca5319a2c0540b9e73bba7047d0e583e5.
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix tols for Unfused
Signed-off-by: Charlene Yang <charleney@nvidia.com>

---------
Signed-off-by: Charlene Yang <charleney@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 7b94bd99
Subproject commit 724f0ec8ce06027feada51f2d948cd3313e63720 Subproject commit f937055efc6d414d11f4c6577e3977fe74f35fb6
...@@ -68,6 +68,7 @@ class TestDistributedSelfAttn: ...@@ -68,6 +68,7 @@ class TestDistributedSelfAttn:
batch, seqlen, num_head, hidden = data_shape batch, seqlen, num_head, hidden = data_shape
if not is_fused_attn_kernel_available( if not is_fused_attn_kernel_available(
is_training,
dtype, dtype,
dtype, dtype,
QKVLayout.BS3HD, QKVLayout.BS3HD,
...@@ -214,6 +215,7 @@ class TestDistributedCrossAttn: ...@@ -214,6 +215,7 @@ class TestDistributedCrossAttn:
batch, seqlen, num_head, hidden = data_shape batch, seqlen, num_head, hidden = data_shape
if not is_fused_attn_kernel_available( if not is_fused_attn_kernel_available(
is_training,
dtype, dtype,
dtype, dtype,
QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BS2HD,
...@@ -346,6 +348,7 @@ class TestDistributedContextParallelSelfAttn: ...@@ -346,6 +348,7 @@ class TestDistributedContextParallelSelfAttn:
def check_has_backend_for_mask(mask_type): def check_has_backend_for_mask(mask_type):
return is_fused_attn_kernel_available( return is_fused_attn_kernel_available(
is_training,
dtype, dtype,
dtype, dtype,
qkv_layout, qkv_layout,
......
...@@ -347,6 +347,7 @@ class FusedAttnRunner: ...@@ -347,6 +347,7 @@ class FusedAttnRunner:
) )
self.backend = FusedAttnHelper( self.backend = FusedAttnHelper(
self.is_training,
self.dtype, self.dtype,
self.dtype, self.dtype,
self.qkv_layout, self.qkv_layout,
......
...@@ -222,13 +222,19 @@ def _get_attention_backends( ...@@ -222,13 +222,19 @@ def _get_attention_backends(
model_configs_base = { model_configs_base = {
# test: b, h, hg, d, sq, skv, p, mask, bias # attn , backend # test: b, h, hg, d, sq, skv, p, mask, bias
"base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0 "base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
"base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), # cross, 0 "base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
"base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), # self , 1 "base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1 "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
"base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"), # inference "base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), # inference "base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_4_0": ModelConfig(8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_4_1": ModelConfig(8, 16, 16, 192, 128, 2048, 0.0, "no_mask", "no_bias"),
"base_5_0": ModelConfig(8, 16, 16, 512, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_5_1": ModelConfig(8, 16, 16, 512, 128, 2048, 0.0, "no_mask", "no_bias"),
"base_6_0": ModelConfig(8, 16, 16, 1024, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_6_1": ModelConfig(8, 16, 16, 1024, 128, 2048, 0.0, "no_mask", "no_bias"),
} }
...@@ -270,14 +276,28 @@ def test_dot_product_attention( ...@@ -270,14 +276,28 @@ def test_dot_product_attention(
if config.window_size == (-1, -1) and swa: if config.window_size == (-1, -1) and swa:
config.window_size = [2, 2] config.window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
is_training = True
available_backends, _, fused_attn_backends = _get_attention_backends( available_backends, _, fused_attn_backends = _get_attention_backends(
config, config,
qkv_dtype=dtype, qkv_dtype=dtype,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
window_size=config.window_size, window_size=config.window_size,
pad_between_seqs=pad_between_seqs, pad_between_seqs=pad_between_seqs,
is_training=is_training,
) )
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported:
is_training = False
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# FlashAttention does not support pad_between_seqs, but _run_dot_product_attention # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
# manually pads and unpads the input and output of FlashAttention for testing purposes # manually pads and unpads the input and output of FlashAttention for testing purposes
...@@ -296,7 +316,6 @@ def test_dot_product_attention( ...@@ -296,7 +316,6 @@ def test_dot_product_attention(
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.") pytest.skip("Less than two backends to compare.")
is_training = config.head_dim_qk <= 128 and config.head_dim_v <= 128
# UnfusedDotProductAttention backend # UnfusedDotProductAttention backend
if unfused_attn_supported: if unfused_attn_supported:
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention( unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
...@@ -360,6 +379,7 @@ def test_dot_product_attention( ...@@ -360,6 +379,7 @@ def test_dot_product_attention(
is_training, is_training,
) )
logging.info(f"[test_dot_product_attention]: is_training = {is_training}")
if unfused_attn_supported and flash_attn_supported: if unfused_attn_supported and flash_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs flash attn") logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols) torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
...@@ -399,18 +419,27 @@ model_configs_mla = { ...@@ -399,18 +419,27 @@ model_configs_mla = {
"mla_1_1": ModelConfig( "mla_1_1": ModelConfig(
4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128 4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128
), # cross, 0 ), # cross, 0
"mla_1_2": ModelConfig(
4, 16, 16, 192, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128
), # cross, 0
"mla_2_0": ModelConfig( "mla_2_0": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64 2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64
), # self , 1 ), # self , 1
"mla_2_1": ModelConfig( "mla_2_1": ModelConfig(
1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64 1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64
), # cross, 1 ), # cross, 1
"mla_2_2": ModelConfig(
1, 24, 24, 192, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=128
), # cross, 1
"mla_3_0": ModelConfig( "mla_3_0": ModelConfig(
8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=64 8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=64
), # inference ), # inference
"mla_3_1": ModelConfig( "mla_3_1": ModelConfig(
8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128 8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128
), # inference ), # inference
"mla_3_2": ModelConfig(
8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128
), # inference
} }
...@@ -1024,6 +1053,8 @@ def _run_dot_product_attention( ...@@ -1024,6 +1053,8 @@ def _run_dot_product_attention(
layer_number=1, layer_number=1,
attention_type=config.attn_type, attention_type=config.attn_type,
).to(dtype=dtype, device="cuda") ).to(dtype=dtype, device="cuda")
if not is_training:
block = block.eval()
# Run a forward and backward pass # Run a forward and backward pass
if backend in ["FlashAttention", "UnfusedDotProductAttention"]: if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
...@@ -1136,14 +1167,29 @@ def test_transformer_layer( ...@@ -1136,14 +1167,29 @@ def test_transformer_layer(
workspace_opt = True workspace_opt = True
# Test backend availability # Test backend availability
is_training = True
available_backends, _, fused_attn_backends = _get_attention_backends( available_backends, _, fused_attn_backends = _get_attention_backends(
config, config,
qkv_dtype=dtype, qkv_dtype=dtype,
qkv_layout=( qkv_layout=(
qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd") qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
), ),
is_training=is_training,
) )
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported:
is_training = False
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=(
qkv_format.replace("hd", "h3d")
if fused_qkv_params
else qkv_format.replace("hd", "3hd")
),
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# Skip if only unfused backend is supported # Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
...@@ -1163,6 +1209,7 @@ def test_transformer_layer( ...@@ -1163,6 +1209,7 @@ def test_transformer_layer(
workspace_opt, workspace_opt,
fused_qkv_params, fused_qkv_params,
RoPE, RoPE,
is_training,
) )
# FusedAttention backend # FusedAttention backend
...@@ -1176,6 +1223,7 @@ def test_transformer_layer( ...@@ -1176,6 +1223,7 @@ def test_transformer_layer(
workspace_opt, workspace_opt,
fused_qkv_params, fused_qkv_params,
RoPE, RoPE,
is_training,
) )
# FlashAttention backend # FlashAttention backend
...@@ -1189,8 +1237,10 @@ def test_transformer_layer( ...@@ -1189,8 +1237,10 @@ def test_transformer_layer(
workspace_opt, workspace_opt,
fused_qkv_params, fused_qkv_params,
RoPE, RoPE,
is_training,
) )
logging.info(f"[test_transformer_layer]: is_training = {is_training}")
if unfused_attn_supported and fused_attn_supported: if unfused_attn_supported and fused_attn_supported:
logging.info("[test_transformer_layer]: unfused attn vs fused attn") logging.info("[test_transformer_layer]: unfused attn vs fused attn")
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols) torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
...@@ -1257,6 +1307,7 @@ def _run_transformer_layer( ...@@ -1257,6 +1307,7 @@ def _run_transformer_layer(
workspace_opt: bool, workspace_opt: bool,
fused_qkv_params: bool, fused_qkv_params: bool,
RoPE: bool, RoPE: bool,
is_training: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run TransformerLayer module with one forward pass and one backward pass""" """Run TransformerLayer module with one forward pass and one backward pass"""
...@@ -1410,6 +1461,8 @@ def _run_transformer_layer( ...@@ -1410,6 +1461,8 @@ def _run_transformer_layer(
bias=True, bias=True,
attn_input_format=qkv_format, attn_input_format=qkv_format,
).to(dtype=dtype, device="cuda") ).to(dtype=dtype, device="cuda")
if not is_training:
block = block.eval()
# Create ALiBi slopes # Create ALiBi slopes
alibi_slopes = None alibi_slopes = None
...@@ -1432,8 +1485,9 @@ def _run_transformer_layer( ...@@ -1432,8 +1485,9 @@ def _run_transformer_layer(
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_kv=cu_seqlens_kv,
) )
loss = out.sum() if is_training:
loss.backward() loss = out.sum()
loss.backward()
return out, inp.grad return out, inp.grad
......
...@@ -52,7 +52,7 @@ model_configs_infer = { ...@@ -52,7 +52,7 @@ model_configs_infer = {
4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16 4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16
), ),
"infer_1": ModelConfig( "infer_1": ModelConfig(
2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16 2, 16, 4, 256, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
), ),
} }
...@@ -370,12 +370,24 @@ def generate_args( ...@@ -370,12 +370,24 @@ def generate_args(
] ]
def get_tols(module, backend, dtype): def get_tols(config, module, backend, dtype):
if module == "TransformerLayer": if module == "TransformerLayer":
tols = { if config.head_dim_qk <= 128:
torch.half: (5e-3, 5e-3), tols = {
torch.bfloat16: (3.5e-2, 3.5e-2), torch.half: (5e-3, 5e-3),
} torch.bfloat16: (3.5e-2, 3.5e-2),
}
else:
if backend == "UnfusedAttention":
tols = {
torch.half: (1.6e-2, 1.6e-2),
torch.bfloat16: (1.2e-1, 1e-1),
}
else:
tols = {
torch.half: (1e-2, 1e-2),
torch.bfloat16: (8e-2, 7e-2),
}
if module == "DotProductAttention": if module == "DotProductAttention":
tols = { tols = {
torch.half: (1e-3, 1e-3), torch.half: (1e-3, 1e-3),
...@@ -662,7 +674,9 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g ...@@ -662,7 +674,9 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
incremental_output = incremental_output[0] incremental_output = incremental_output[0]
# compare results # compare results
atol, rtol = get_tols(module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn) atol, rtol = get_tols(
config, module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn
)
for i, seq in enumerate(sim.t_seq_ids): for i, seq in enumerate(sim.t_seq_ids):
token_index = sim.step_lens[i] - 1 token_index = sim.step_lens[i] - 1
if qkv_format == "bshd": if qkv_format == "bshd":
......
...@@ -134,10 +134,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { ...@@ -134,10 +134,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) {
// select a backend for fused attention // select a backend for fused attention
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
int64_t window_size_left, int64_t window_size_right) { size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) {
using namespace transformer_engine; using namespace transformer_engine;
NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
const int device_id = cuda::current_device(); const int device_id = cuda::current_device();
...@@ -216,12 +216,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -216,12 +216,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
} }
if ( if (
// TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging
// special conditions for blackwell
// TODO: enable THD max_t in f16_arbitrary_seqlen when support becomes available in 9.7
!(sm_arch_ >= 100 && (head_dim_qk > 128 || head_dim_v > 128)) &&
// architecture // architecture
((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) || ((cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90)) ||
(cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) && (cudnn_runtime_version >= 8903 && sm_arch_ >= 80 && sm_arch_ < 100) ||
(cudnn_runtime_version >= 90700 && sm_arch_ >= 80)) &&
// sequence length // sequence length
((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0) || ((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0) ||
(cudnn_runtime_version >= 90000)) && (cudnn_runtime_version >= 90000)) &&
...@@ -229,11 +227,28 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -229,11 +227,28 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) || ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) ||
(cudnn_runtime_version >= 8907)) && (cudnn_runtime_version >= 8907)) &&
// head dimension // head dimension
((head_dim_qk <= 128 && head_dim_qk % 8 == 0 && head_dim_v <= 128 && head_dim_v % 8 == 0) || // multiples of 8
// TODO (cyang): add is_training to nvte_get_fused_attn_backend (head_dim_qk % 8 == 0 && head_dim_v % 8 == 0 &&
// d=256 only supported for forward // <= 128
(sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim_qk <= 256 && ((head_dim_qk <= 128 && head_dim_v <= 128) ||
head_dim_qk % 8 == 0 && head_dim_v <= 256 && head_dim_v % 8 == 0)) && // 9.1: <= 256 + Hopper + fprop
// 9.5: <= 256 + Hopper + bprop
(head_dim_qk <= 256 && head_dim_v <= 256 &&
((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) ||
(is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) ||
// 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1
(!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 &&
layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) ||
// 9.10: any head_dim + any arch + fprop + paged
// 9.10: any head_dim + any arch + fprop + non_paged + sq > 1
// 9.10: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM}
(!is_training && cudnn_runtime_version >= 91000 &&
(layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD || max_seqlen_q > 1 ||
(max_seqlen_q == 1 && attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK &&
attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) ||
// 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged
(head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 &&
cudnn_runtime_version >= 91100))) &&
// bias type // bias type
((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
(cudnn_runtime_version >= 8906 && (cudnn_runtime_version >= 8906 &&
...@@ -423,8 +438,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, ...@@ -423,8 +438,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype); const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h,
max_seqlen, d, d, window_size_left, window_size_right); max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901) #if (CUDNN_VERSION >= 8901)
...@@ -505,7 +520,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con ...@@ -505,7 +520,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype); const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
max_seqlen, d, d, window_size_left, window_size_right); max_seqlen, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
...@@ -636,8 +651,8 @@ void nvte_fused_attn_fwd_kvpacked( ...@@ -636,8 +651,8 @@ void nvte_fused_attn_fwd_kvpacked(
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype); const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
max_seqlen_kv, d, d, window_size_left, window_size_right); max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901) #if (CUDNN_VERSION >= 8901)
...@@ -731,8 +746,8 @@ void nvte_fused_attn_bwd_kvpacked( ...@@ -731,8 +746,8 @@ void nvte_fused_attn_bwd_kvpacked(
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype); const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
max_seqlen_kv, d, d, window_size_left, window_size_right); max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901) #if (CUDNN_VERSION >= 8901)
...@@ -862,8 +877,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso ...@@ -862,8 +877,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype); const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901) #if (CUDNN_VERSION >= 8901)
...@@ -954,8 +969,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso ...@@ -954,8 +969,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype); const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901) #if (CUDNN_VERSION >= 8901)
......
...@@ -172,6 +172,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); ...@@ -172,6 +172,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
/*! \brief Get fused attention backend based on input parameters. /*! \brief Get fused attention backend based on input parameters.
* *
* \param[in] is_training Whether the model is in training mode.
* \param[in] q_dtype The data type of Tensor Q. * \param[in] q_dtype The data type of Tensor Q.
* \param[in] kv_dtype The data type of Tensors K, V. * \param[in] kv_dtype The data type of Tensors K, V.
* \param[in] qkv_layout The layout of Tensors Q, K, V. * \param[in] qkv_layout The layout of Tensors Q, K, V.
...@@ -188,10 +189,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); ...@@ -188,10 +189,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
* \param[in] window_size_right Sliding window size (the right half). * \param[in] window_size_right Sliding window size (the right half).
*/ */
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
int64_t window_size_left, int64_t window_size_right); size_t head_dim_v, int64_t window_size_left, int64_t window_size_right);
/*! \brief Compute dot product attention with packed QKV input. /*! \brief Compute dot product attention with packed QKV input.
* *
......
...@@ -277,6 +277,7 @@ def canonicalize_attn_mask_type(attn_mask_type: str): ...@@ -277,6 +277,7 @@ def canonicalize_attn_mask_type(attn_mask_type: str):
def is_fused_attn_kernel_available( def is_fused_attn_kernel_available(
is_training,
q_dtype, q_dtype,
kv_dtype, kv_dtype,
qkv_layout, qkv_layout,
...@@ -296,6 +297,7 @@ def is_fused_attn_kernel_available( ...@@ -296,6 +297,7 @@ def is_fused_attn_kernel_available(
def make_helper(attn_mask_type): def make_helper(attn_mask_type):
return tex.FusedAttnHelper( return tex.FusedAttnHelper(
is_training,
q_dtype, q_dtype,
kv_dtype, kv_dtype,
qkv_layout, qkv_layout,
......
...@@ -103,6 +103,7 @@ class FusedAttnHelper: ...@@ -103,6 +103,7 @@ class FusedAttnHelper:
Helper for the fused attention backend Helper for the fused attention backend
""" """
is_training: bool
q_dtype: jnp.dtype q_dtype: jnp.dtype
kv_dtype: jnp.dtype kv_dtype: jnp.dtype
qkv_layout: QKVLayout qkv_layout: QKVLayout
...@@ -123,6 +124,7 @@ class FusedAttnHelper: ...@@ -123,6 +124,7 @@ class FusedAttnHelper:
def get_fused_attn_backend(self): def get_fused_attn_backend(self):
"""Get the fused attention kernel backend""" """Get the fused attention kernel backend"""
return transformer_engine_jax.get_fused_attn_backend( return transformer_engine_jax.get_fused_attn_backend(
self.is_training,
jax_dtype_to_te_dtype(self.q_dtype), jax_dtype_to_te_dtype(self.q_dtype),
jax_dtype_to_te_dtype(self.kv_dtype), jax_dtype_to_te_dtype(self.kv_dtype),
self.qkv_layout.value, self.qkv_layout.value,
...@@ -276,6 +278,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -276,6 +278,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
# backend determines the softmax buffer shape/dtype # backend determines the softmax buffer shape/dtype
backend = FusedAttnHelper( backend = FusedAttnHelper(
config.is_training,
q_dtype, q_dtype,
k_dtype, k_dtype,
config.qkv_layout, config.qkv_layout,
......
...@@ -96,7 +96,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler); ...@@ -96,7 +96,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability, NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_num_heads, size_t kv_num_heads, size_t q_num_heads, size_t kv_num_heads,
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability, NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_attn_heads, size_t kv_attn_heads, size_t q_attn_heads, size_t kv_attn_heads,
...@@ -19,9 +19,9 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, ...@@ -19,9 +19,9 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
size_t head_dim, int64_t window_size_left, size_t head_dim, int64_t window_size_left,
int64_t window_size_right) { int64_t window_size_right) {
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type, is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, bias_type, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen,
head_dim, head_dim, window_size_left, window_size_right); kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right);
return backend; return backend;
} }
...@@ -263,9 +263,9 @@ static void FusedAttnForwardImpl( ...@@ -263,9 +263,9 @@ static void FusedAttnForwardImpl(
/* Prepare RNG state */ /* Prepare RNG state */
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64); auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type, is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen,
head_dim, head_dim, window_size_left, window_size_right); kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right);
nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
/* Auxiliary tensors (to be propagated to the backward pass later) */ /* Auxiliary tensors (to be propagated to the backward pass later) */
...@@ -518,9 +518,9 @@ static void FusedAttnBackwardImpl( ...@@ -518,9 +518,9 @@ static void FusedAttnBackwardImpl(
NVTETensorPack aux_input_tensors; NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors); nvte_tensor_pack_create(&aux_input_tensors);
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type, is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen,
head_dim, head_dim, window_size_left, window_size_right); kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
softmax_aux, rng_state, bias); softmax_aux, rng_state, bias);
......
...@@ -596,6 +596,8 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -596,6 +596,8 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
seqlen_kv = key.shape[sequence_dim] seqlen_kv = key.shape[sequence_dim]
has_fused_attn_kernel = is_fused_attn_kernel_available( has_fused_attn_kernel = is_fused_attn_kernel_available(
# This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode.
not deterministic,
self.dtype, self.dtype,
self.dtype, self.dtype,
qkv_layout, qkv_layout,
......
...@@ -761,6 +761,7 @@ def get_attention_backend( ...@@ -761,6 +761,7 @@ def get_attention_backend(
q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
kv_type = q_type kv_type = q_type
fused_attention_backend = tex.get_fused_attn_backend( fused_attention_backend = tex.get_fused_attn_backend(
is_training,
q_type, q_type,
kv_type, kv_type,
QKVLayout[qkv_layout], QKVLayout[qkv_layout],
......
...@@ -35,13 +35,11 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T ...@@ -35,13 +35,11 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
* Attention * Attention
**************************************************************************************************/ **************************************************************************************************/
NVTE_Fused_Attn_Backend get_fused_attn_backend(const DType q_dtype, const DType kv_dtype, NVTE_Fused_Attn_Backend get_fused_attn_backend(
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Mask_Type attn_mask_type, float p_dropout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads,
size_t num_attn_heads, size_t num_gqa_groups, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right);
size_t head_dim_qk, size_t head_dim_v,
int64_t window_size_left, int64_t window_size_right);
std::vector<py::object> fused_attn_fwd( std::vector<py::object> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
......
...@@ -57,14 +57,14 @@ namespace transformer_engine::pytorch { ...@@ -57,14 +57,14 @@ namespace transformer_engine::pytorch {
// get the fused attention backend // get the fused attention backend
NVTE_Fused_Attn_Backend get_fused_attn_backend( NVTE_Fused_Attn_Backend get_fused_attn_backend(
const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
int64_t window_size_left, int64_t window_size_right) { size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) {
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type, is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, bias_type, attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q,
head_dim_qk, head_dim_v, window_size_left, window_size_right); max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right);
return fused_attention_backend; return fused_attention_backend;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment