Commit 8b745f2c authored by Jim's avatar Jim
Browse files

codegen template

parent f4489897
...@@ -126,3 +126,9 @@ BOOL_MAP = { ...@@ -126,3 +126,9 @@ BOOL_MAP = {
"t" : "true", "t" : "true",
"f" : "false" "f" : "false"
} }
BF16_CVT_MAP = {
0 : "rtne",
1 : "rtna",
2 : "rtz",
}
\ No newline at end of file
...@@ -161,6 +161,13 @@ std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>() ...@@ -161,6 +161,13 @@ std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
""" """
FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp" FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp"
FMHA_BWD_V3_TEMPLATE="""
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_{F_hdim_name}_{F_dtype_name}_{F_causal_name}_{F_atomic_name}_{F_bf16_cvt_name}"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_{F_hdim_name}_{F_dtype_name}_{F_causal_name}_{F_atomic_name}_{F_bf16_cvt_name}; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr int ts_qo = {F_Ts_qo}; static constexpr int ts_kv = 192; }};
"""
FMHA_BWD_API=""" FMHA_BWD_API="""
#include <iostream> #include <iostream>
#include "hsaco/fmha_hsaco.hpp" #include "hsaco/fmha_hsaco.hpp"
...@@ -311,158 +318,7 @@ struct fmha_bwd_dq_dk_dv_v3_traits_ ...@@ -311,158 +318,7 @@ struct fmha_bwd_dq_dk_dv_v3_traits_
static constexpr bool kIsHDPad = kIsHDPad_; static constexpr bool kIsHDPad = kIsHDPad_;
}}; }};
template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Name; {F_template}
// ########################################################|HDim| DataType|kIsCausal|kIsAtomic32|BF16Cvt|kIsHDPad|
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_a16"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_a32"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_causal_a16"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_causal_a32"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtne_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtna_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtz_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtne_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtna_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtz_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_a16_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_a32_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_causal_a16_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_causal_a32_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 1, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 2, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 1, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 2, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 1, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 2, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 1, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 2, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, false, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_a16"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, true, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_a32"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, false, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_causal_a16"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, true, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32"; }};
template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Buf;
// #######################################################|HDim| DataType|kIsCausal|kIsAtomic32|BF16Cvt|kIsHDPad|
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_a16; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_a32; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_causal_a16; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_causal_a32; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtne_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtna_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtz_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtne_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtna_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtz_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtne_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtna_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtz_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtne_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtna_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtz_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_a16_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_a32_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_causal_a16_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_causal_a32_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a16_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 1, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a16_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 2, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a16_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 1, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 2, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a16_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 1, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a16_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 2, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a16_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 1, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 2, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, false, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_a16; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, true, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_a32; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, false, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_causal_a16; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, true, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_causal_a32; }};
template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Ts;
// ######################################################|HDim| DataType|kIsCausal|kIsAtomic32|BF16Cvt|kIsHDPad|
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 0, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 1, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 2, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 0, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 1, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 2, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 0, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 1, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 2, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 0, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 1, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 2, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, false, 0, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, true, 0, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, false, 0, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, true, 0, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
class fmha_bwd_v3_kernel class fmha_bwd_v3_kernel
{{ {{
...@@ -732,7 +588,17 @@ float fmha_bwd_v3_hdp_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a) ...@@ -732,7 +588,17 @@ float fmha_bwd_v3_hdp_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a)
float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
float r = -1; float r = -1;
if ((t.uses_bwd_v3 == true){{
if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && (t.is_deterministic == false) && (a.hdim_q == a.hdim_v) &&
(a.seqlen_q == a.seqlen_k) && (a.nhead_q % a.nhead_k == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
(a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) &&
((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))) {{
if (((a.hdim_q >= 64) && (a.hdim_q <= 128) && (a.hdim_q % 8 == 0) && (a.seqlen_k % 64 == 0))) {{
{F_v3_dispatch} {F_v3_dispatch}
}}
}}
}}
{F_dispatch} {F_dispatch}
return r; return r;
}} }}
...@@ -757,23 +623,18 @@ FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) ...@@ -757,23 +623,18 @@ FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode})
}} }}
""" """
FMHA_V3_DISPATCH=""" if ((t.uses_bwd_v3 == {F_use_v3}){{ FMHA_V3_DISPATCH="""
if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && (t.is_deterministic == false) && (a.hdim_q == a.hdim_v) &&
(a.seqlen_q == a.seqlen_k) && (a.nhead_q % a.nhead_k == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
(a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) &&
((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))) {{
if (((a.hdim_q >= 64) && (a.hdim_q <= 128) && (a.hdim_q % 8 == 0) && (a.seqlen_k % 64 == 0))) {{
if (t.mask_type == mask_enum::no_mask){{ if (t.mask_type == mask_enum::no_mask){{
if ((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{ if ((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, false, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}, false>;
r = fmha_bwd_v3{F_padding_suffix}_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3{F_padding_suffix}_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
else if (t.is_v3_atomic_fp32 == false){{ else if (t.is_v3_atomic_fp32 == false){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, false, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>;
r = fmha_bwd_v3{F_padding_suffix}_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3{F_padding_suffix}_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
...@@ -781,21 +642,19 @@ FMHA_V3_DISPATCH=""" if ((t.uses_bwd_v3 == {F_use_v3}){{ ...@@ -781,21 +642,19 @@ FMHA_V3_DISPATCH=""" if ((t.uses_bwd_v3 == {F_use_v3}){{
else if ((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{ else if ((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if ((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{ if ((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, true, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}, false>;
r = fmha_bwd_v3{F_padding_suffix}_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3{F_padding_suffix}_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
else if (t.is_v3_atomic_fp32 == false){{ else if (t.is_v3_atomic_fp32 == false){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, true, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>;
r = fmha_bwd_v3{F_padding_suffix}_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3{F_padding_suffix}_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
}} }}
}} }}
}}
}}
""" """
@dataclass @dataclass
...@@ -818,9 +677,6 @@ class FmhaBwdDQDKDVApiTrait: ...@@ -818,9 +677,6 @@ class FmhaBwdDQDKDVApiTrait:
dpad : str dpad : str
dvpad : str dvpad : str
deterministic : str deterministic : str
uses_bwd_v3 : str
v3_atomic_fp32 : str
how_v3_bf16_cvt : int
def scheck(self, spad1 : str) -> str: def scheck(self, spad1 : str) -> str:
if self.mode == 'group': if self.mode == 'group':
...@@ -851,6 +707,15 @@ class FmhaBwdDQDKDVApiTrait: ...@@ -851,6 +707,15 @@ class FmhaBwdDQDKDVApiTrait:
if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0'
else : return f'a.hdim_v % {self.bhdv} == 0' else : return f'a.hdim_v % {self.bhdv} == 0'
@dataclass
class FmhaBwdV3DQDKDVApiTrait:
hdim : str
dtype : str # data type
is_causal : str
is_atomic : str
bf16_cvt : int
hd_pad : str
class FmhaBwdApiPool: class FmhaBwdApiPool:
def __init__(self, mask_impl): def __init__(self, mask_impl):
self.dq_dk_dv_pool = dict() self.dq_dk_dv_pool = dict()
...@@ -885,7 +750,6 @@ class FmhaBwdApiPool: ...@@ -885,7 +750,6 @@ class FmhaBwdApiPool:
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype],
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_deterministic=BOOL_MAP[trait.deterministic]) F_deterministic=BOOL_MAP[trait.deterministic])
if_j = 'if' if j == 0 else 'else if' if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if' if_i = 'if' if i == 0 else 'else if'
...@@ -893,20 +757,30 @@ class FmhaBwdApiPool: ...@@ -893,20 +757,30 @@ class FmhaBwdApiPool:
if not per_dtypes: if not per_dtypes:
# empty string we add some ignore to suppress warning in api # empty string we add some ignore to suppress warning in api
per_dtypes += ' (void)t ; (void)s ; (void)a;' per_dtypes += ' (void)t ; (void)s ; (void)a;'
# v3 codegen:
v3_code = str() v3_code = str()
gen_template = str()
for i, dtype in enumerate(self.dq_dk_dv_pool.keys()): for i, dtype in enumerate(self.dq_dk_dv_pool.keys()):
per_hdim_case=str()
for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()): for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()):
traits = self.dq_dk_dv_pool[dtype][hdim] traits = self.dq_dk_dv_pool[dtype][hdim]
hdim_int = int(hdim) hdim_int = int(hdim)
hdim_int = (hdim_int + 64 - 1) / 64 * 64 hdim_int = (hdim_int + 64 - 1) / 64 * 64
Ts_qo = 32 if hdim == 64 else 16
for k, trait in enumerate(traits): for k, trait in enumerate(traits):
padding = "t" if hdim_int % 64 == 0 else "f" padding = "t" if hdim_int % 64 == 0 else "f"
padding_suffix = "_hdp" if padding == "t" else "" padding_suffix = "_hdp" if padding == "t" else ""
v3_code = v3_code + FMHA_V3_DISPATCH.format(F_use_v3=BOOL_MAP[trait.use_bwd_v3], F_hdim=hdim_int, F_dtype=BWD_DTYPE_MAP[dtype], F_padding=BOOL_MAP[padding], v3_code = v3_code + FMHA_V3_DISPATCH.format(F_hdim=hdim_int, F_dtype=BWD_DTYPE_MAP[dtype], F_padding=BOOL_MAP[padding],
F_is_atomic32=BOOL_MAP[trait.v3_atomic_fp32], F_how_v3_bf16_cvt=trait.how_v3_bf16_cvt, F_padding_suffix=padding_suffix) F_is_atomic32=BOOL_MAP[trait.is_atomic], F_how_v3_bf16_cvt=trait.bf16_cvt, F_padding_suffix=padding_suffix)
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes, F_v3_dispatch = v3_code)
hdim_name = "hd64" if hdim_int == 64 else ""
dtype_name = dtype
causal_name = "causal" if trait.is_causal == "t" else ""
atomic_name = "a32" if trait.is_atomic == "t" else "a16"
bf16_cvt_name = BF16_CVT_MAP[trait.bf16_cvt]
gen_template = gen_template + FMHA_BWD_V3_TEMPLATE.format(F_hdim=hdim_int, F_dtype=BWD_DTYPE_MAP[dtype], F_is_atomic=BOOL_MAP[trait.is_atomic],
F_is_causal=BOOL_MAP[trait.is_causal], F_bf16_cvt=trait.bf16_cvt, F_hdpad=BOOL_MAP[padding], F_Ts_qo = Ts_qo, F_hdim_name=hdim_name,
F_dtype_name=dtype_name, F_causal_name=causal_name, F_atomic_name=atomic_name, F_bf16_cvt_name=bf16_cvt_name)
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes, F_template = gen_template, F_v3_dispatch = v3_code)
# GEMM0: Q@K=S^T # GEMM0: Q@K=S^T
# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v) # GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v)
...@@ -1038,7 +912,6 @@ class FmhaBwdDQDKDVKernel: ...@@ -1038,7 +912,6 @@ class FmhaBwdDQDKDVKernel:
def filename(self) -> str: def filename(self) -> str:
return self.name + ".cpp" return self.name + ".cpp"
# TODO: Jim: update this
def api_trait(self) -> FmhaBwdDQDKDVApiTrait: def api_trait(self) -> FmhaBwdDQDKDVApiTrait:
return FmhaBwdDQDKDVApiTrait(pipeline=self.F_pipeline, return FmhaBwdDQDKDVApiTrait(pipeline=self.F_pipeline,
hdim=str(self.F_hdim), hdim=str(self.F_hdim),
...@@ -1056,10 +929,7 @@ class FmhaBwdDQDKDVKernel: ...@@ -1056,10 +929,7 @@ class FmhaBwdDQDKDVKernel:
skpad=self.F_skpad, skpad=self.F_skpad,
dpad=self.F_dpad, dpad=self.F_dpad,
dvpad=self.F_dvpad, dvpad=self.F_dvpad,
deterministic=self.F_deterministic, deterministic=self.F_deterministic
use_bwd_v3=self.use_bwd_v3,
v3_atomic_fp32=self.v3_atomic_fp32,
how_v3_bf16_cvt=self.how_v3_bf16_cvt
) )
# TODO: design a more practical way to do it # TODO: design a more practical way to do it
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment