Unverified commit 696ad6c4, authored by cyanguwa, committed by GitHub
Browse files

[Common/PyTorch] Fix FP8 fused attention input args (#592)



fix FP8 dims
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent f2bd53c4
......@@ -842,7 +842,7 @@ param_types_fp8 = [torch.float16]
@pytest.mark.skipif(_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability != (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("model", model_configs_fp8.keys())
def test_dpa_fp8(dtype, model):
......
......@@ -1862,8 +1862,7 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
#if (CUDNN_VERSION >= 8900)
// fused attention FWD FP8 with packed QKV
void fused_attn_fp8_fwd_qkvpacked(
size_t b, size_t max_seqlen,
size_t h, size_t d,
size_t b, size_t h, size_t max_seqlen, size_t d,
bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_QKV,
......@@ -1960,8 +1959,7 @@ void fused_attn_fp8_fwd_qkvpacked(
}
// fused attention BWD FP8 with packed QKV
void fused_attn_fp8_bwd_qkvpacked(
size_t b, size_t max_seqlen,
size_t h, size_t d,
size_t b, size_t h, size_t max_seqlen, size_t d,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_QKV,
const Tensor *input_O,
......@@ -2055,8 +2053,7 @@ void fused_attn_fp8_bwd_qkvpacked(
}
// fused attention FWD FP8 with separate Q, K, V
void fused_attn_fp8_fwd(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t h, size_t d,
size_t b, size_t h, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_Q,
......@@ -2156,8 +2153,7 @@ void fused_attn_fp8_fwd(
}
// fused attention BWD FP8 with separate Q, K, V
void fused_attn_fp8_bwd(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t h, size_t d,
size_t b, size_t h, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_Q,
const Tensor *input_K,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment