Commit 466b82a5 authored by danyao12's avatar danyao12
Browse files

add data type config to FAv3

parent cd4d4629
...@@ -301,108 +301,108 @@ struct fmha_bwd_dq_dk_dv_v3_traits_ ...@@ -301,108 +301,108 @@ struct fmha_bwd_dq_dk_dv_v3_traits_
template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Name; template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Name;
// ########################################################|HDim| DataType|kIsCausal|kIsAtomic32|kIsSpec|BF16Cvt| // ########################################################|HDim| DataType|kIsCausal|kIsAtomic32|kIsSpec|BF16Cvt|
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtna"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 1>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtne"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, false, 1>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 2>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtz"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, false, 2>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtna"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 1>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtne"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, false, 1>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 2>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtz"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, false, 2>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 1>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, false, 1>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 2>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, false, 2>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 1>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, false, 1>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 2>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, false, 2>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, true, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_spec_a32"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, true, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_spec_a32"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, true, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_spec_causal_a32"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, true, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_spec_causal_a32"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, false, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_a16"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_a16"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, true, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_a32"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_a32"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, false, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_causal_a16"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_causal_a16"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, true, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_causal_a32"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_causal_a32"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, true, true, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_spec_a32"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, true, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_spec_a32"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, true, true, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_spec_causal_a32"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, true, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_spec_causal_a32"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, false, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, false, false, 1>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, false, 1>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, false, false, 2>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, false, 2>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, true, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, true, false, 1>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, false, 1>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, true, false, 2>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, false, 2>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, false, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, false, false, 1>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, false, 1>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, false, false, 2>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, false, 2>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, true, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, true, false, 1>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, false, 1>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, true, false, 2>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz"; }}; template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, false, 2>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz"; }};
template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Buf; template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Buf;
// #######################################################|HDim| DataType|kIsCausal|kIsAtomic32|kIsSpec|BF16Cvt| // #######################################################|HDim| DataType|kIsCausal|kIsAtomic32|kIsSpec|BF16Cvt|
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtna; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 1>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtne; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, false, 1>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 2>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtz; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, false, 2>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtna; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 1>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtne; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, false, 1>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 2>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtz; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, false, 2>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtna; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 1>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtne; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, false, 1>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 2>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtz; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, false, 2>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtna; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 1>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtne; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, false, 1>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 2>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtz; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, false, 2>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, true, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_spec_a32; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, true, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_spec_a32; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, true, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_spec_causal_a32; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, true, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_spec_causal_a32; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, false, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_a16; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_a16; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, true, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_a32; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_a32; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, false, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_causal_a16; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_causal_a16; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, true, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_causal_a32; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_causal_a32; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, true, true, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_spec_a32; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, true, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_spec_a32; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, true, true, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_spec_causal_a32; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, true, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_spec_causal_a32; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, false, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a16_rtna; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a16_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, false, false, 1>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a16_rtne; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, false, 1>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a16_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, false, false, 2>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a16_rtz; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, false, 2>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a16_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, true, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtna; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, true, false, 1>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtne; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, false, 1>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, true, false, 2>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtz; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, false, 2>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, false, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a16_rtna; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a16_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, false, false, 1>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a16_rtne; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, false, 1>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a16_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, false, false, 2>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a16_rtz; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, false, 2>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a16_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, true, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtna; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, true, false, 1>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtne; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, false, 1>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, true, false, 2>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtz; }}; template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, false, 2>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtz; }};
template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Ts; template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Ts;
// ######################################################|HDim| DataType|kIsCausal|kIsAtomic32|kIsSpec|BF16Cvt| // ######################################################|HDim| DataType|kIsCausal|kIsAtomic32|kIsSpec|BF16Cvt|
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 0>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, false, 0>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 1>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, false, 1>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 2>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, false, 2>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 0>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, false, 0>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 1>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, false, 1>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 2>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, false, 2>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 0>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, false, 0>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 1>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, false, 1>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 2>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, false, 2>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 0>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, false, 0>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 1>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, false, 1>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 2>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, false, 2>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, true, 0>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 128; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, true, 0>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 128; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, true, 0>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 128; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, true, 0>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 128; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, false, false, 0>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, false, 0>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, true, false, 0>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, false, 0>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, false, false, 0>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, false, 0>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, true, false, 0>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, false, 0>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, true, true, 0>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 128; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, true, 0>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 128; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, true, true, 0>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 128; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, true, 0>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 128; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, false, false, 0>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, false, 0>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, false, false, 1>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, false, 1>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, false, false, 2>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, false, 2>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, true, false, 0>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, false, 0>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, true, false, 1>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, false, 1>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, true, false, 2>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, false, 2>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, false, false, 0>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, false, 0>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, false, false, 1>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, false, 1>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, false, false, 2>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, false, 2>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, true, false, 0>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, false, 0>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, true, false, 1>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, false, 1>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, true, false, 2>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, false, 2>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
class fmha_bwd_v3_kernel class fmha_bwd_v3_kernel
{{ {{
...@@ -643,26 +643,26 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -643,26 +643,26 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
if(t.mask_type == mask_enum::no_mask){{ if(t.mask_type == mask_enum::no_mask){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{ if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{ if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, true, true, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, true, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_spec_a32"; // const std::string bwd_v3_name = "bwd_v3_fp16_spec_a32";
bool io_perm = a.nhead_stride_q > a.stride_q; bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r; return r;
}} }}
else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{ else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, true, false, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, false, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_a32"; // const std::string bwd_v3_name = "bwd_v3_fp16_a32";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
}} }}
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{ else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, false, false, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, false, 0>;
// const std::string bwd_v3_name = "bwd_v3_fp16_a16"; // const std::string bwd_v3_name = "bwd_v3_fp16_a16";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
...@@ -671,26 +671,26 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -671,26 +671,26 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{ else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{ if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{ if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, true, true, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, true, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_spec_causal_a32"; // const std::string bwd_v3_name = "bwd_v3_fp16_spec_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q; bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r; return r;
}} }}
else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{ else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, true, false, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, false, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32"; // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
}} }}
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{ else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, false, false, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, false, 0>;
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16"; // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
...@@ -701,9 +701,9 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -701,9 +701,9 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
if(t.mask_type == mask_enum::no_mask){{ if(t.mask_type == mask_enum::no_mask){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{ if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{ if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, true, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, true, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_spec_a32"; // const std::string bwd_v3_name = "bwd_v3_bf16_spec_a32";
bool io_perm = a.nhead_stride_q > a.stride_q; bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
...@@ -711,25 +711,25 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -711,25 +711,25 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
}} }}
else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{ else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 0){{ if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, false, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne"; // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 1){{ else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 1>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, false, 1>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna"; // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 2){{ else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 2>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, false, 2>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz"; // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
...@@ -738,22 +738,22 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -738,22 +738,22 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
}} }}
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{ else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 0){{ if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, false, 0>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne"; // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 1){{ else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 1>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, false, 1>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna"; // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 2){{ else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 2>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, false, 2>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz"; // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
...@@ -763,9 +763,9 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -763,9 +763,9 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{ else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{ if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{ if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, true, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, true, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_spec_causal_a32"; // const std::string bwd_v3_name = "bwd_v3_bf16_spec_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q; bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
...@@ -773,25 +773,25 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -773,25 +773,25 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
}} }}
else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{ else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 0){{ if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, false, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne"; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 1){{ else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 1>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, false, 1>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna"; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 2){{ else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 2>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, false, 2>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz"; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
...@@ -800,22 +800,22 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -800,22 +800,22 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
}} }}
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{ else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 0){{ if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, false, 0>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne"; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 1){{ else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 1>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, false, 1>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna"; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 2){{ else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 2>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, false, 2>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz"; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
...@@ -829,25 +829,25 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -829,25 +829,25 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
if(t.mask_type == mask_enum::no_mask){{ if(t.mask_type == mask_enum::no_mask){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{ if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.how_v3_bf16_cvt == 0){{ if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, true, false, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, false, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 1){{ else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, true, false, 1>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, false, 1>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 2){{ else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, true, false, 2>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, false, 2>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
...@@ -855,22 +855,22 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -855,22 +855,22 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
}} }}
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{ else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 0){{ if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, false, false, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, false, 0>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 1){{ else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, false, false, 1>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, false, 1>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 2){{ else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, false, false, 2>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, false, 2>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
...@@ -880,25 +880,25 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -880,25 +880,25 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{ else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{ if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.how_v3_bf16_cvt == 0){{ if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, true, false, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, false, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 1){{ else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, true, false, 1>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, false, 1>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 2){{ else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, true, false, 2>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, false, 2>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
...@@ -906,22 +906,22 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -906,22 +906,22 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
}} }}
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{ else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 0){{ if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, false, false, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, false, 0>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne"; const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 1){{ else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, false, false, 1>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, false, 1>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 2){{ else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, false, false, 2>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, false, 2>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment