Commit 43d81903 authored by danyao12's avatar danyao12
Browse files

add templates

parent 2defe2f6
......@@ -283,6 +283,127 @@ struct fmha_bwd_v3_traits
int ts_kv;
}};
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsCausal_,
bool kIsAtomic32_,
bool kIsSpec_,
ck_tile::index_t BF16Cvt_>
struct fmha_bwd_dq_dk_dv_v3_traits_
{{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsCausal = kIsCausal_;
static constexpr bool kIsAtomic32 = kIsAtomic32_;
static constexpr bool kIsSpec = kIsSpec_;
static constexpr ck_tile::index_t BF16Cvt = BF16Cvt_;
}};
template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Name;
// ########################################################|HDim| DataType|kIsCausal|kIsAtomic32|kIsSpec|BF16Cvt|
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 1>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 2>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 1>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 2>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 1>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 2>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 1>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 2>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, true, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_spec_a32"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, true, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_spec_causal_a32"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, false, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_a16"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, true, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_a32"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, false, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_causal_a16"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, true, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_causal_a32"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, true, true, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_spec_a32"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, true, true, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_spec_causal_a32"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, false, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, false, false, 1>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, false, false, 2>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, true, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, true, false, 1>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, true, false, 2>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, false, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, false, false, 1>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, false, false, 2>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, true, false, 0>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, true, false, 1>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, true, false, 2>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz"; }};
template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Buf;
// #######################################################|HDim| DataType|kIsCausal|kIsAtomic32|kIsSpec|BF16Cvt|
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 1>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 2>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 1>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 2>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 1>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 2>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 1>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 2>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, true, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_spec_a32; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, true, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_spec_causal_a32; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, false, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_a16; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, true, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_a32; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, false, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_causal_a16; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, true, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_causal_a32; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, true, true, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_spec_a32; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, true, true, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_spec_causal_a32; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, false, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a16_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, false, false, 1>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a16_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, false, false, 2>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a16_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, true, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, true, false, 1>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, true, false, 2>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, false, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a16_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, false, false, 1>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a16_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, false, false, 2>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a16_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, true, false, 0>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, true, false, 1>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, true, false, 2>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtz; }};
template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Ts;
// ######################################################|HDim| DataType|kIsCausal|kIsAtomic32|kIsSpec|BF16Cvt|
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 0>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 1>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 2>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 0>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 1>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 2>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 0>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 1>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 2>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 0>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 1>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 2>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, true, 0>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 128; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, true, 0>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 128; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, false, false, 0>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, true, false, 0>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, false, false, 0>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, true, false, 0>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, true, true, 0>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 128; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, true, true, 0>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 128; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, false, false, 0>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, false, false, 1>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, false, false, 2>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, true, false, 0>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, true, false, 1>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, false, true, false, 2>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, false, false, 0>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, false, false, 1>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, false, false, 2>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, true, false, 0>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, true, false, 1>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, ck_tile::bf16_t, true, true, false, 2>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
class fmha_bwd_v3_kernel
{{
public:
......@@ -373,11 +494,11 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
);
}}
template <typename dot_do_o_trait_>
float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_v3_buf[], const std::string& bwd_v3_name, bool io_perm, int ts_qo, int ts_kv)
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_>
float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, bool io_perm)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_v3_name << std::flush;
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << std::flush;
fmha_bwd_xqa_v3_args args;
args.ptr_dq = a.dq_ptr;
args.ptr_dk = a.dk_ptr;
......@@ -392,7 +513,7 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsign
args.log2e = ck_tile::log2e_v<float>;
args.seq_len = a.seqlen_q;
int stride_tg = ts_kv * a.hdim_q * 2;
int stride_tg = FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv * a.hdim_q * 2;
int stride_head = a.seqlen_q * a.hdim_q * 2;
int stride_batch = a.nhead_q * a.seqlen_q * a.hdim_q * 2;
int stride_seqlen = a.hdim_q * 2;
......@@ -408,7 +529,7 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsign
stride_seqlen_kv = a.nhead_k * a.hdim_q * 2;
stride_seqlen_dkv = a.nhead_q * a.hdim_q * 2;
stride_tg = ts_kv * stride_seqlen_kv;
stride_tg = FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv * stride_seqlen_kv;
stride_head_kv = a.hdim_q * 2;
}}
args.Ts = stride_tg;
......@@ -427,20 +548,20 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsign
a.hdim_q,
1,
a.mask_type,
ts_qo,
ts_kv}};
static fmha_bwd_v3_kernel impl(HSA_KERNEL, bwd_v3_buf); // static here is for thread safety.
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_qo,
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv}};
static fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf<dq_dk_dv_v3_traits_>::bwd_v3_buf); // static here is for thread safety.
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }}
);
}}
template <typename dot_do_o_trait_, typename convert_dq_trait_>
float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_v3_buf[], const std::string& bwd_v3_name, bool io_perm, int ts_qo, int ts_kv)
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_, typename convert_dq_trait_>
float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a, bool io_perm)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
fmha_bwd_v3_args args;
args.ptr_dq = a.dq_acc_ptr;
args.ptr_dk = a.dk_ptr;
......@@ -455,14 +576,14 @@ float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned c
args.log2e = ck_tile::log2e_v<float>;
args.seq_len = a.seqlen_q;
int stride_tg = ts_kv * a.hdim_q * 2;
int stride_tg = FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv * a.hdim_q * 2;
int stride_head = a.seqlen_q * a.hdim_q * 2;
int stride_batch = a.nhead_q * a.seqlen_q * a.hdim_q * 2;
int stride_seqlen = a.hdim_q * 2;
if(io_perm == 0) //BSHD
{{
stride_seqlen = a.nhead_q * a.hdim_q * 2;
stride_tg = ts_kv * stride_seqlen;
stride_tg = FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv * stride_seqlen;
stride_head = a.hdim_q * 2;
}}
args.Ts = stride_tg;
......@@ -475,9 +596,9 @@ float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned c
a.hdim_q,
1,
a.mask_type,
ts_qo,
ts_kv}};
static fmha_bwd_v3_kernel impl(HSA_KERNEL, bwd_v3_buf); // static here is for thread safety.
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_qo,
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv}};
static fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf<dq_dk_dv_v3_traits_>::bwd_v3_buf); // static here is for thread safety.
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }},
......@@ -485,11 +606,11 @@ float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned c
);
}}
template <typename dot_do_o_trait_, typename convert_dq_trait_>
float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_v3_buf[], const std::string& bwd_v3_name, bool io_perm, int ts_qo, int ts_kv)
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_, typename convert_dq_trait_>
float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, bool io_perm)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
fmha_bwd_xqa_v3_args args;
args.ptr_dq = a.dq_acc_ptr;
args.ptr_dk = a.dk_ptr;
......@@ -504,7 +625,7 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsign
args.log2e = ck_tile::log2e_v<float>;
args.seq_len = a.seqlen_q;
int stride_tg = ts_kv * a.hdim_q * 2;
int stride_tg = FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv * a.hdim_q * 2;
int stride_head = a.seqlen_q * a.hdim_q * 2;
int stride_batch = a.nhead_q * a.seqlen_q * a.hdim_q * 2;
int stride_seqlen = a.hdim_q * 2;
......@@ -520,7 +641,7 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsign
stride_seqlen_kv = a.nhead_k * a.hdim_q * 2;
stride_seqlen_dkv = a.nhead_q * a.hdim_q * 2;
stride_tg = ts_kv * stride_seqlen_kv;
stride_tg = FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv * stride_seqlen_kv;
stride_head_kv = a.hdim_q * 2;
}}
args.Ts = stride_tg;
......@@ -539,9 +660,9 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsign
a.hdim_q,
1,
a.mask_type,
ts_qo,
ts_kv}};
static fmha_bwd_v3_kernel impl(HSA_KERNEL, bwd_v3_buf); // static here is for thread safety.
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_qo,
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv}};
static fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf<dq_dk_dv_v3_traits_>::bwd_v3_buf); // static here is for thread safety.
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }},
......@@ -561,26 +682,29 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, true, true, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_spec_a32";
// const std::string bwd_v3_name = "bwd_v3_fp16_spec_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_spec_a32, bwd_v3_name, io_perm, 32, 128);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r;
}}
else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, true, false, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_a32";
// const std::string bwd_v3_name = "bwd_v3_fp16_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_a32, bwd_v3_name, io_perm, 16, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r;
}}
}}
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_a16";
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, false, false, 0>;
// const std::string bwd_v3_name = "bwd_v3_fp16_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_fp16_a16, bwd_v3_name, io_perm, 16, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r;
}}
}}
......@@ -588,26 +712,29 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, true, true, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_spec_causal_a32";
// const std::string bwd_v3_name = "bwd_v3_fp16_spec_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_spec_causal_a32, bwd_v3_name, io_perm, 32, 128);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r;
}}
else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, true, false, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32";
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_causal_a32, bwd_v3_name, io_perm, 16, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r;
}}
}}
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16";
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, false, false, 0>;
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_fp16_causal_a16, bwd_v3_name, io_perm, 16, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r;
}}
}}
......@@ -617,35 +744,39 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, true, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_spec_a32";
// const std::string bwd_v3_name = "bwd_v3_bf16_spec_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_spec_a32, bwd_v3_name, io_perm, 32, 128);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r;
}}
else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne";
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32_rtne, bwd_v3_name, io_perm, 16, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r;
}}
else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 1>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna";
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32_rtna, bwd_v3_name, io_perm, 16, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 2>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz";
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32_rtz, bwd_v3_name, io_perm, 16, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r;
}}
}}
......@@ -653,23 +784,26 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne";
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 0>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16_rtne, bwd_v3_name, io_perm, 16, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r;
}}
else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna";
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 1>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16_rtna, bwd_v3_name, io_perm, 16, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz";
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 2>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16_rtz, bwd_v3_name, io_perm, 16, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r;
}}
}}
......@@ -678,35 +812,39 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, true, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_spec_causal_a32";
// const std::string bwd_v3_name = "bwd_v3_bf16_spec_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_spec_causal_a32, bwd_v3_name, io_perm, 32, 128);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r;
}}
else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne";
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32_rtne, bwd_v3_name, io_perm, 16, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r;
}}
else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 1>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna";
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32_rtna, bwd_v3_name, io_perm, 16, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 2>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz";
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32_rtz, bwd_v3_name, io_perm, 16, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r;
}}
}}
......@@ -714,23 +852,26 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne";
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 0>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16_rtne, bwd_v3_name, io_perm, 16, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r;
}}
else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna";
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 1>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16_rtna, bwd_v3_name, io_perm, 16, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz";
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 2>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16_rtz, bwd_v3_name, io_perm, 16, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r;
}}
}}
......@@ -743,49 +884,55 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, true, false, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne";
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_hd64_bf16_a32_rtne, bwd_v3_name, io_perm, 32, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r;
}}
else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, true, false, 1>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna";
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_hd64_bf16_a32_rtna, bwd_v3_name, io_perm, 32, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, true, false, 2>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz";
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_hd64_bf16_a32_rtz, bwd_v3_name, io_perm, 32, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r;
}}
}}
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne";
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, false, false, 0>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_hd64_bf16_a16_rtne, bwd_v3_name, io_perm, 32, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r;
}}
else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna";
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, false, false, 1>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_hd64_bf16_a16_rtna, bwd_v3_name, io_perm, 32, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz";
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, false, false, 2>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_hd64_bf16_a16_rtz, bwd_v3_name, io_perm, 32, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r;
}}
}}
......@@ -794,49 +941,55 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, true, false, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne";
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_hd64_bf16_causal_a32_rtne, bwd_v3_name, io_perm, 32, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r;
}}
else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, true, false, 1>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna";
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_hd64_bf16_causal_a32_rtna, bwd_v3_name, io_perm, 32, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, true, false, 2>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz";
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_hd64_bf16_causal_a32_rtz, bwd_v3_name, io_perm, 32, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r;
}}
}}
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, false, false, 0>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_hd64_bf16_causal_a16_rtne, bwd_v3_name, io_perm, 32, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r;
}}
else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna";
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, false, false, 1>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_hd64_bf16_causal_a16_rtna, bwd_v3_name, io_perm, 32, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz";
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, false, false, 2>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_hd64_bf16_causal_a16_rtz, bwd_v3_name, io_perm, 32, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r;
}}
}}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment