"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "a69692ac0ba441c109096b7fe376d5a1bef3c78e"
Unverified commit 71c76b6b, authored by Charlene Yang and committed by GitHub

Add support for head_dim > 128 (#1797)



* add support for head dim > 128
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* remove debugging
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* raise tols slightly to tolerate 1/2048 mismatches
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix is_training for test_te_layer
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add bprop support for blackwell
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor tweak for format
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix backend selection results
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* bump sm100 to sm100+
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add sq=1 test for MLA
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* enable sq=1 for bprop
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* minor tweak in comments
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix head_dim logic and remove pytest skip
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add FE fix for d>128
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* update FE again to take in small fixes
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add cuDNN version info in L0 tests
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* increase tols for Unfused + large dim
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* Revert "add cuDNN version info in L0 tests"

This reverts commit 3e1b426ca5319a2c0540b9e73bba7047d0e583e5.
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix tols for Unfused
Signed-off-by: Charlene Yang <charleney@nvidia.com>

---------

Signed-off-by: Charlene Yang <charleney@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
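
Not part of this commit, but as context for the diff below: a minimal, hedged sketch of driving a head dimension above 128 through the PyTorch `DotProductAttention` module. The shapes, dtype, and the `no_mask` setting are illustrative assumptions, not values taken from this PR, and constructor defaults may differ between TE versions.

```python
import torch
from transformer_engine.pytorch import DotProductAttention

# Illustrative sizes only: 16 heads with head_dim 192 (> 128), which previously
# fell back to the unfused backend; with this change the fused backend can serve
# the forward pass on supported GPU / cuDNN combinations.
b, sq, skv, h, d = 2, 128, 128, 16, 192
dpa = DotProductAttention(num_attention_heads=h, kv_channels=d,
                          attn_mask_type="no_mask").to(dtype=torch.bfloat16, device="cuda")
dpa.eval()  # inference mode: only the forward pass needs backend support

q = torch.randn(sq, b, h, d, dtype=torch.bfloat16, device="cuda")
k = torch.randn(skv, b, h, d, dtype=torch.bfloat16, device="cuda")
v = torch.randn(skv, b, h, d, dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    out = dpa(q, k, v)  # default qkv_format is "sbhd": [seq, batch, heads, dim]
```

In training mode the backward-pass restrictions in the backend-selection logic below still apply, which is why the tests in this PR first query the backends with is_training=True and fall back to inference-only runs when the fused kernel cannot serve bprop for the given head dimension.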
parent 7b94bd99
Subproject commit 724f0ec8ce06027feada51f2d948cd3313e63720
Subproject commit f937055efc6d414d11f4c6577e3977fe74f35fb6
......@@ -68,6 +68,7 @@ class TestDistributedSelfAttn:
batch, seqlen, num_head, hidden = data_shape
if not is_fused_attn_kernel_available(
is_training,
dtype,
dtype,
QKVLayout.BS3HD,
......@@ -214,6 +215,7 @@ class TestDistributedCrossAttn:
batch, seqlen, num_head, hidden = data_shape
if not is_fused_attn_kernel_available(
is_training,
dtype,
dtype,
QKVLayout.BSHD_BS2HD,
......@@ -346,6 +348,7 @@ class TestDistributedContextParallelSelfAttn:
def check_has_backend_for_mask(mask_type):
return is_fused_attn_kernel_available(
is_training,
dtype,
dtype,
qkv_layout,
......
......@@ -347,6 +347,7 @@ class FusedAttnRunner:
)
self.backend = FusedAttnHelper(
self.is_training,
self.dtype,
self.dtype,
self.qkv_layout,
......
......@@ -222,13 +222,19 @@ def _get_attention_backends(
model_configs_base = {
# test: b, h, hg, d, sq, skv, p, mask, bias # attn , backend
"base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0
"base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), # cross, 0
"base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), # self , 1
"base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1
"base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
"base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
# test: b, h, hg, d, sq, skv, p, mask, bias
"base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
"base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
"base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
"base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_4_0": ModelConfig(8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_4_1": ModelConfig(8, 16, 16, 192, 128, 2048, 0.0, "no_mask", "no_bias"),
"base_5_0": ModelConfig(8, 16, 16, 512, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_5_1": ModelConfig(8, 16, 16, 512, 128, 2048, 0.0, "no_mask", "no_bias"),
"base_6_0": ModelConfig(8, 16, 16, 1024, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_6_1": ModelConfig(8, 16, 16, 1024, 128, 2048, 0.0, "no_mask", "no_bias"),
}
......@@ -270,14 +276,28 @@ def test_dot_product_attention(
if config.window_size == (-1, -1) and swa:
config.window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
is_training = True
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported:
is_training = False
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
# manually pads and unpads the input and output of FlashAttention for testing purposes
......@@ -296,7 +316,6 @@ def test_dot_product_attention(
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
is_training = config.head_dim_qk <= 128 and config.head_dim_v <= 128
# UnfusedDotProductAttention backend
if unfused_attn_supported:
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
......@@ -360,6 +379,7 @@ def test_dot_product_attention(
is_training,
)
logging.info(f"[test_dot_product_attention]: is_training = {is_training}")
if unfused_attn_supported and flash_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
......@@ -399,18 +419,27 @@ model_configs_mla = {
"mla_1_1": ModelConfig(
4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128
), # cross, 0
"mla_1_2": ModelConfig(
4, 16, 16, 192, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128
), # cross, 0
"mla_2_0": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64
), # self , 1
"mla_2_1": ModelConfig(
1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64
), # cross, 1
"mla_2_2": ModelConfig(
1, 24, 24, 192, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=128
), # cross, 1
"mla_3_0": ModelConfig(
8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=64
), # inference
"mla_3_1": ModelConfig(
8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128
), # inference
"mla_3_2": ModelConfig(
8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128
), # inference
}
......@@ -1024,6 +1053,8 @@ def _run_dot_product_attention(
layer_number=1,
attention_type=config.attn_type,
).to(dtype=dtype, device="cuda")
if not is_training:
block = block.eval()
# Run a forward and backward pass
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
......@@ -1136,14 +1167,29 @@ def test_transformer_layer(
workspace_opt = True
# Test backend availability
is_training = True
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=(
qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
),
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported:
is_training = False
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=(
qkv_format.replace("hd", "h3d")
if fused_qkv_params
else qkv_format.replace("hd", "3hd")
),
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
......@@ -1163,6 +1209,7 @@ def test_transformer_layer(
workspace_opt,
fused_qkv_params,
RoPE,
is_training,
)
# FusedAttention backend
......@@ -1176,6 +1223,7 @@ def test_transformer_layer(
workspace_opt,
fused_qkv_params,
RoPE,
is_training,
)
# FlashAttention backend
......@@ -1189,8 +1237,10 @@ def test_transformer_layer(
workspace_opt,
fused_qkv_params,
RoPE,
is_training,
)
logging.info(f"[test_transformer_layer]: is_training = {is_training}")
if unfused_attn_supported and fused_attn_supported:
logging.info("[test_transformer_layer]: unfused attn vs fused attn")
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
......@@ -1257,6 +1307,7 @@ def _run_transformer_layer(
workspace_opt: bool,
fused_qkv_params: bool,
RoPE: bool,
is_training: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run TransformerLayer module with one forward pass and one backward pass"""
......@@ -1410,6 +1461,8 @@ def _run_transformer_layer(
bias=True,
attn_input_format=qkv_format,
).to(dtype=dtype, device="cuda")
if not is_training:
block = block.eval()
# Create ALiBi slopes
alibi_slopes = None
......@@ -1432,8 +1485,9 @@ def _run_transformer_layer(
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
)
loss = out.sum()
loss.backward()
if is_training:
loss = out.sum()
loss.backward()
return out, inp.grad
......
......@@ -52,7 +52,7 @@ model_configs_infer = {
4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16
),
"infer_1": ModelConfig(
2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
2, 16, 4, 256, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
),
}
......@@ -370,12 +370,24 @@ def generate_args(
]
def get_tols(module, backend, dtype):
def get_tols(config, module, backend, dtype):
if module == "TransformerLayer":
tols = {
torch.half: (5e-3, 5e-3),
torch.bfloat16: (3.5e-2, 3.5e-2),
}
if config.head_dim_qk <= 128:
tols = {
torch.half: (5e-3, 5e-3),
torch.bfloat16: (3.5e-2, 3.5e-2),
}
else:
if backend == "UnfusedAttention":
tols = {
torch.half: (1.6e-2, 1.6e-2),
torch.bfloat16: (1.2e-1, 1e-1),
}
else:
tols = {
torch.half: (1e-2, 1e-2),
torch.bfloat16: (8e-2, 7e-2),
}
if module == "DotProductAttention":
tols = {
torch.half: (1e-3, 1e-3),
......@@ -662,7 +674,9 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
incremental_output = incremental_output[0]
# compare results
atol, rtol = get_tols(module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn)
atol, rtol = get_tols(
config, module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn
)
for i, seq in enumerate(sim.t_seq_ids):
token_index = sim.step_lens[i] - 1
if qkv_format == "bshd":
......
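
Illustration only: a hedged sketch of how the widened tolerances are consumed. `get_tols` and `incremental_output` come from the test changes above, while `config` and `full_output` stand in for the test's actual objects.

```python
import torch

# For head_dim_qk > 128, get_tols returns looser (atol, rtol) pairs, with the
# loosest values reserved for the UnfusedAttention backend in bf16.
atol, rtol = get_tols(config, module="TransformerLayer",
                      backend="UnfusedAttention", dtype=torch.bfloat16)
torch.testing.assert_close(incremental_output, full_output, atol=atol, rtol=rtol)
```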
......@@ -134,10 +134,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) {
// select a backend for fused attention
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
int64_t window_size_left, int64_t window_size_right) {
bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads,
size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) {
using namespace transformer_engine;
NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
const int device_id = cuda::current_device();
......@@ -216,12 +216,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
}
if (
// TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging
// special conditions for blackwell
// TODO: enable THD max_t in f16_arbitrary_seqlen when support becomes available in 9.7
!(sm_arch_ >= 100 && (head_dim_qk > 128 || head_dim_v > 128)) &&
// architecture
((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) ||
(cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) &&
((cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90)) ||
(cudnn_runtime_version >= 8903 && sm_arch_ >= 80 && sm_arch_ < 100) ||
(cudnn_runtime_version >= 90700 && sm_arch_ >= 80)) &&
// sequence length
((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0) ||
(cudnn_runtime_version >= 90000)) &&
......@@ -229,11 +227,28 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) ||
(cudnn_runtime_version >= 8907)) &&
// head dimension
((head_dim_qk <= 128 && head_dim_qk % 8 == 0 && head_dim_v <= 128 && head_dim_v % 8 == 0) ||
// TODO (cyang): add is_training to nvte_get_fused_attn_backend
// d=256 only supported for forward
(sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim_qk <= 256 &&
head_dim_qk % 8 == 0 && head_dim_v <= 256 && head_dim_v % 8 == 0)) &&
// multiples of 8
(head_dim_qk % 8 == 0 && head_dim_v % 8 == 0 &&
// <= 128
((head_dim_qk <= 128 && head_dim_v <= 128) ||
// 9.1: <= 256 + Hopper + fprop
// 9.5: <= 256 + Hopper + bprop
(head_dim_qk <= 256 && head_dim_v <= 256 &&
((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) ||
(is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) ||
// 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1
(!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 &&
layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) ||
// 9.10: any head_dim + any arch + fprop + paged
// 9.10: any head_dim + any arch + fprop + non_paged + sq > 1
// 9.10: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM}
(!is_training && cudnn_runtime_version >= 91000 &&
(layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD || max_seqlen_q > 1 ||
(max_seqlen_q == 1 && attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK &&
attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) ||
// 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged
(head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 &&
cudnn_runtime_version >= 91100))) &&
// bias type
((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
(cudnn_runtime_version >= 8906 &&
......@@ -423,8 +438,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
max_seqlen, d, d, window_size_left, window_size_right);
is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h,
max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -505,7 +520,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
max_seqlen, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
......@@ -636,8 +651,8 @@ void nvte_fused_attn_fwd_kvpacked(
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
max_seqlen_kv, d, d, window_size_left, window_size_right);
is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -731,8 +746,8 @@ void nvte_fused_attn_bwd_kvpacked(
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
max_seqlen_kv, d, d, window_size_left, window_size_right);
true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -862,8 +877,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -954,8 +969,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......
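
For readability, a rough Python paraphrase of the head-dimension clause added above. This is an approximation for illustration only; the authoritative logic is the C++ condition in `nvte_get_fused_attn_backend`, and the helper name and argument names here are invented.

```python
def head_dim_supported(d_qk, d_v, sm_arch, cudnn, is_training,
                       max_seqlen_q, is_paged_kv, mask_type):
    """Approximate mirror of the head-dimension checks for the F16 arbitrary-seqlen backend."""
    if d_qk % 8 != 0 or d_v % 8 != 0:
        return False  # head dims must be multiples of 8
    if d_qk <= 128 and d_v <= 128:
        return True   # baseline support
    if d_qk <= 256 and d_v <= 256 and sm_arch == 90:
        # Hopper: <= 256 for fprop since cuDNN 9.1, for bprop since 9.5
        return cudnn >= (90500 if is_training else 90100)
    if (not is_training and sm_arch >= 100 and cudnn >= 90900
            and max_seqlen_q > 1 and not is_paged_kv):
        return True   # 9.9: any head_dim + Blackwell + fprop + non-paged + sq > 1
    if not is_training and cudnn >= 91000 and (
            is_paged_kv or max_seqlen_q > 1
            or mask_type not in ("causal", "padding_causal")):
        return True   # 9.10: any head_dim + any arch + fprop (paged, sq > 1, or sq = 1 with non-causal masks)
    if (is_training and d_qk == 192 and d_v == 128
            and sm_arch >= 100 and cudnn >= 91100):
        return True   # 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop
    return False
```

For instance, the new base_5_1 test config (d = 512, sq = 128, no_mask) would satisfy the 9.9 branch in the forward pass on sm100+, matching the C++ condition above.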
......@@ -172,6 +172,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
/*! \brief Get fused attention backend based on input parameters.
*
* \param[in] is_training Whether the model is in training mode.
* \param[in] q_dtype The data type of Tensor Q.
* \param[in] kv_dtype The data type of Tensors K, V.
* \param[in] qkv_layout The layout of Tensors Q, K, V.
......@@ -188,10 +189,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
* \param[in] window_size_right Sliding window size (the right half).
*/
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
int64_t window_size_left, int64_t window_size_right);
bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads,
size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
size_t head_dim_v, int64_t window_size_left, int64_t window_size_right);
/*! \brief Compute dot product attention with packed QKV input.
*
......
......@@ -277,6 +277,7 @@ def canonicalize_attn_mask_type(attn_mask_type: str):
def is_fused_attn_kernel_available(
is_training,
q_dtype,
kv_dtype,
qkv_layout,
......@@ -296,6 +297,7 @@ def is_fused_attn_kernel_available(
def make_helper(attn_mask_type):
return tex.FusedAttnHelper(
is_training,
q_dtype,
kv_dtype,
qkv_layout,
......
......@@ -103,6 +103,7 @@ class FusedAttnHelper:
Helper for the fused attention backend
"""
is_training: bool
q_dtype: jnp.dtype
kv_dtype: jnp.dtype
qkv_layout: QKVLayout
......@@ -123,6 +124,7 @@ class FusedAttnHelper:
def get_fused_attn_backend(self):
"""Get the fused attention kernel backend"""
return transformer_engine_jax.get_fused_attn_backend(
self.is_training,
jax_dtype_to_te_dtype(self.q_dtype),
jax_dtype_to_te_dtype(self.kv_dtype),
self.qkv_layout.value,
......@@ -276,6 +278,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
# backend determines the softmax buffer shape/dtype
backend = FusedAttnHelper(
config.is_training,
q_dtype,
k_dtype,
config.qkv_layout,
......
......@@ -96,7 +96,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_num_heads, size_t kv_num_heads,
......
......@@ -11,7 +11,7 @@
namespace transformer_engine {
namespace jax {
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_attn_heads, size_t kv_attn_heads,
......@@ -19,9 +19,9 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
size_t head_dim, int64_t window_size_left,
int64_t window_size_right) {
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen,
head_dim, head_dim, window_size_left, window_size_right);
is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
bias_type, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen,
kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right);
return backend;
}
......@@ -263,9 +263,9 @@ static void FusedAttnForwardImpl(
/* Prepare RNG state */
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim, head_dim, window_size_left, window_size_right);
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen,
kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right);
nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
/* Auxiliary tensors (to be propagated to the backward pass later) */
......@@ -518,9 +518,9 @@ static void FusedAttnBackwardImpl(
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim, head_dim, window_size_left, window_size_right);
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen,
kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
softmax_aux, rng_state, bias);
......
......@@ -596,6 +596,8 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
seqlen_kv = key.shape[sequence_dim]
has_fused_attn_kernel = is_fused_attn_kernel_available(
# This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode.
not deterministic,
self.dtype,
self.dtype,
qkv_layout,
......
......@@ -761,6 +761,7 @@ def get_attention_backend(
q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
kv_type = q_type
fused_attention_backend = tex.get_fused_attn_backend(
is_training,
q_type,
kv_type,
QKVLayout[qkv_layout],
......
......@@ -35,13 +35,11 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
* Attention
**************************************************************************************************/
NVTE_Fused_Attn_Backend get_fused_attn_backend(const DType q_dtype, const DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, float p_dropout,
size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv,
size_t head_dim_qk, size_t head_dim_v,
int64_t window_size_left, int64_t window_size_right);
NVTE_Fused_Attn_Backend get_fused_attn_backend(
bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads,
size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
size_t head_dim_v, int64_t window_size_left, int64_t window_size_right);
std::vector<py::object> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
......
......@@ -57,14 +57,14 @@ namespace transformer_engine::pytorch {
// get the fused attention backend
NVTE_Fused_Attn_Backend get_fused_attn_backend(
const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
int64_t window_size_left, int64_t window_size_right) {
bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads,
size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) {
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv,
head_dim_qk, head_dim_v, window_size_left, window_size_right);
is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
bias_type, attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q,
max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right);
return fused_attention_backend;
}
......