Unverified Commit 696ad6c4 authored by cyanguwa, committed by GitHub

[Common/PyTorch] Fix FP8 fused attention input args (#592)



fix FP8 dims: reorder the FP8 fused attention dim arguments to (b, h, max_seqlen, d) / (b, h, max_seqlen_q, max_seqlen_kv, d), and call get_device_compute_capability() in the FP8 test skip condition
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent f2bd53c4
@@ -842,7 +842,7 @@ param_types_fp8 = [torch.float16]
 @pytest.mark.skipif(_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.")
 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.skipif(get_device_compute_capability != (9, 0), reason="FP8 tests require Hopper.")
+@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
 @pytest.mark.parametrize("dtype", param_types_fp8)
 @pytest.mark.parametrize("model", model_configs_fp8.keys())
 def test_dpa_fp8(dtype, model):
...
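The one-character test fix above deserves a note: the old condition compared the function object get_device_compute_capability itself, not its return value, against the tuple (9, 0). A function object never equals a tuple, so the condition was always True and pytest skipped test_dpa_fp8 on every device, Hopper included. A minimal standalone sketch of the two conditions (the stub helper below is hypothetical and stands in for the real one):

import pytest

def get_device_compute_capability():
    # Hypothetical stub for the real helper; pretend we are on Hopper.
    return (9, 0)

# Buggy form: the function object is compared to a tuple, which is
# always unequal, so the skip condition is unconditionally True.
assert (get_device_compute_capability != (9, 0)) is True

# Fixed form: call the helper, then compare its result.
assert (get_device_compute_capability() != (9, 0)) is False

@pytest.mark.skipif(get_device_compute_capability() != (9, 0),
                    reason="FP8 tests require Hopper.")
def test_runs_only_on_hopper():
    pass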
@@ -1862,8 +1862,7 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
 #if (CUDNN_VERSION >= 8900)
 // fused attention FWD FP8 with packed QKV
 void fused_attn_fp8_fwd_qkvpacked(
-    size_t b, size_t max_seqlen,
-    size_t h, size_t d,
+    size_t b, size_t h, size_t max_seqlen, size_t d,
     bool is_training, float attn_scale,
     float p_dropout, NVTE_QKV_Layout qkv_layout,
     const Tensor *input_QKV,
@@ -1960,8 +1959,7 @@ void fused_attn_fp8_fwd_qkvpacked(
 }
 // fused attention BWD FP8 with packed QKV
 void fused_attn_fp8_bwd_qkvpacked(
-    size_t b, size_t max_seqlen,
-    size_t h, size_t d,
+    size_t b, size_t h, size_t max_seqlen, size_t d,
     float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
     const Tensor *input_QKV,
     const Tensor *input_O,
@@ -2055,8 +2053,7 @@ void fused_attn_fp8_bwd_qkvpacked(
 }
 // fused attention FWD FP8 with separate Q, K, V
 void fused_attn_fp8_fwd(
-    size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
-    size_t h, size_t d,
+    size_t b, size_t h, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
     bool is_training, float attn_scale,
     float p_dropout, NVTE_QKV_Layout qkv_layout,
     const Tensor *input_Q,
@@ -2156,8 +2153,7 @@ void fused_attn_fp8_fwd(
 }
 // fused attention BWD FP8 with separate Q, K, V
 void fused_attn_fp8_bwd(
-    size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
-    size_t h, size_t d,
+    size_t b, size_t h, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
     float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
     const Tensor *input_Q,
     const Tensor *input_K,
...
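All four signature changes above make the same edit: the dimension arguments move from (b, max_seqlen, h, d) to (b, h, max_seqlen, d), and the separate-Q/K/V variants from (b, max_seqlen_q, max_seqlen_kv, h, d) to (b, h, max_seqlen_q, max_seqlen_kv, d), presumably to match the (b, h, s_q, s_kv, ...) order used by fused_attn_fp8_bwd_impl, visible in the hunk context. Because all of these parameters are the same integer type, a caller/callee mismatch compiles and runs without complaint; it just silently swaps the head count and sequence length. A small Python sketch of that failure mode, with made-up dimension values:

def fused_attn_fp8_fwd_qkvpacked(b, h, max_seqlen, d):
    # Stand-in for the C++ entry point, taking dims in the NEW order.
    return {"batch": b, "heads": h, "seqlen": max_seqlen, "head_dim": d}

# Hypothetical dims, chosen only to make the swap visible.
b, h, max_seqlen, d = 2, 16, 512, 64

# A caller still using the OLD (b, max_seqlen, h, d) order raises no
# error; it just describes a 512-head, 16-token problem instead of the
# intended 16-head, 512-token one.
wrong = fused_attn_fp8_fwd_qkvpacked(b, max_seqlen, h, d)
right = fused_attn_fp8_fwd_qkvpacked(b, h, max_seqlen, d)
assert wrong != right
print(wrong)   # {'batch': 2, 'heads': 512, 'seqlen': 16, 'head_dim': 64}
print(right)   # {'batch': 2, 'heads': 16, 'seqlen': 512, 'head_dim': 64}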