Commit cf70a2ef authored by danyao12's avatar danyao12
Browse files

rename

parent dc3d35a9
...@@ -189,7 +189,7 @@ struct p2 ...@@ -189,7 +189,7 @@ struct p2
unsigned int _p1; unsigned int _p1;
}}; }};
struct __attribute__((packed)) fmha_bwd_xqa_v3_args struct __attribute__((packed)) fmha_bwd_v3_args
{{ {{
void* ptr_dq; void* ptr_dq;
p2 _p0; p2 _p0;
...@@ -235,7 +235,7 @@ struct __attribute__((packed)) fmha_bwd_xqa_v3_args ...@@ -235,7 +235,7 @@ struct __attribute__((packed)) fmha_bwd_xqa_v3_args
p3 _p20; p3 _p20;
}}; }};
struct __attribute__((packed)) fmha_bwd_xqa_v3_dp_args struct __attribute__((packed)) fmha_bwd_v3_gen_args
{{ {{
void* ptr_dq; void* ptr_dq;
p2 _p0; p2 _p0;
...@@ -474,7 +474,7 @@ class fmha_bwd_v3_kernel ...@@ -474,7 +474,7 @@ class fmha_bwd_v3_kernel
}} }}
void void
launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_xqa_v3_args args, const ck_tile::stream_config& s) const launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_v3_args args, const ck_tile::stream_config& s) const
{{ {{
size_t arg_size = sizeof(args); size_t arg_size = sizeof(args);
void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER, void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER,
...@@ -506,7 +506,7 @@ class fmha_bwd_v3_kernel ...@@ -506,7 +506,7 @@ class fmha_bwd_v3_kernel
}} }}
void void
launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_xqa_v3_dp_args args, const ck_tile::stream_config& s) const launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_v3_gen_args args, const ck_tile::stream_config& s) const
{{ {{
size_t arg_size = sizeof(args); size_t arg_size = sizeof(args);
void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER, void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER,
...@@ -555,11 +555,11 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) ...@@ -555,11 +555,11 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
}} }}
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_> template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_>
float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a) float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{ {{
if(s.log_level_ > 0) if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << std::flush; std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << std::flush;
fmha_bwd_xqa_v3_args args; fmha_bwd_v3_args args;
args.ptr_dq = a.dq_ptr; args.ptr_dq = a.dq_ptr;
args.ptr_dk = a.dk_ptr; args.ptr_dk = a.dk_ptr;
args.ptr_dv = a.dv_ptr; args.ptr_dv = a.dv_ptr;
...@@ -598,11 +598,11 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a) ...@@ -598,11 +598,11 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a)
}} }}
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_> template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_>
float fmha_bwd_v3_hdp_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a) float fmha_bwd_v3_gen_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{ {{
if(s.log_level_ > 0) if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << std::flush; std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << std::flush;
fmha_bwd_xqa_v3_dp_args args; fmha_bwd_v3_gen_args args;
args.ptr_dq = a.dq_ptr; args.ptr_dq = a.dq_ptr;
args.ptr_dk = a.dk_ptr; args.ptr_dk = a.dk_ptr;
args.ptr_dv = a.dv_ptr; args.ptr_dv = a.dv_ptr;
...@@ -642,11 +642,11 @@ float fmha_bwd_v3_hdp_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a) ...@@ -642,11 +642,11 @@ float fmha_bwd_v3_hdp_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a)
}} }}
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_, typename convert_dq_trait_> template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_, typename convert_dq_trait_>
float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a) float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{ {{
if(s.log_level_ > 0) if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush; std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
fmha_bwd_xqa_v3_args args; fmha_bwd_v3_args args;
args.ptr_dq = a.dq_acc_ptr; args.ptr_dq = a.dq_acc_ptr;
args.ptr_dk = a.dk_ptr; args.ptr_dk = a.dk_ptr;
args.ptr_dv = a.dv_ptr; args.ptr_dv = a.dv_ptr;
...@@ -686,11 +686,11 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a) ...@@ -686,11 +686,11 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a)
}} }}
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_, typename convert_dq_trait_> template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_, typename convert_dq_trait_>
float fmha_bwd_v3_hdp_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a) float fmha_bwd_v3_gen_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{ {{
if(s.log_level_ > 0) if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush; std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
fmha_bwd_xqa_v3_dp_args args; fmha_bwd_v3_gen_args args;
args.ptr_dq = a.dq_acc_ptr; args.ptr_dq = a.dq_acc_ptr;
args.ptr_dk = a.dk_ptr; args.ptr_dk = a.dk_ptr;
args.ptr_dv = a.dv_ptr; args.ptr_dv = a.dv_ptr;
...@@ -747,7 +747,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -747,7 +747,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_a32"; // const std::string bwd_v3_name = "bwd_v3_fp16_a32";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
else{{ else{{
...@@ -755,7 +755,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -755,7 +755,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_a32_pddv"; // const std::string bwd_v3_name = "bwd_v3_fp16_a32_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
}} }}
...@@ -764,14 +764,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -764,14 +764,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_a16"; // const std::string bwd_v3_name = "bwd_v3_fp16_a16";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
else{{ else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, true>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, true>;
// const std::string bwd_v3_name = "bwd_v3_fp16_a16_pddv"; // const std::string bwd_v3_name = "bwd_v3_fp16_a16_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
}} }}
...@@ -783,7 +783,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -783,7 +783,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32"; // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
else{{ else{{
...@@ -791,7 +791,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -791,7 +791,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32_pddv"; // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
}} }}
...@@ -800,14 +800,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -800,14 +800,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16"; // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
else{{ else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, true>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, true>;
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16_pddv"; // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
}} }}
...@@ -822,7 +822,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -822,7 +822,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne"; // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
else{{ else{{
...@@ -830,7 +830,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -830,7 +830,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne_pddv"; // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
}} }}
...@@ -840,7 +840,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -840,7 +840,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna"; // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
else{{ else{{
...@@ -848,7 +848,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -848,7 +848,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna_pddv"; // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
}} }}
...@@ -858,7 +858,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -858,7 +858,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz"; // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
else{{ else{{
...@@ -866,7 +866,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -866,7 +866,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz_pddv"; // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
}} }}
...@@ -877,14 +877,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -877,14 +877,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne"; // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
else{{ else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, true>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, true>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne_pddv"; // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
}} }}
...@@ -893,14 +893,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -893,14 +893,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna"; // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
else{{ else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, true>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, true>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna_pddv"; // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
}} }}
...@@ -909,14 +909,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -909,14 +909,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz"; // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
else{{ else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, true>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, true>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz_pddv"; // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
}} }}
...@@ -930,7 +930,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -930,7 +930,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne"; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
else{{ else{{
...@@ -938,7 +938,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -938,7 +938,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_pddv"; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
}} }}
...@@ -948,7 +948,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -948,7 +948,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna"; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
else{{ else{{
...@@ -956,7 +956,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -956,7 +956,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_pddv"; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
}} }}
...@@ -966,7 +966,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -966,7 +966,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz"; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
else{{ else{{
...@@ -974,7 +974,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -974,7 +974,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_pddv"; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
}} }}
...@@ -985,14 +985,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -985,14 +985,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne"; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
else{{ else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, true>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, true>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne_pddv"; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
}} }}
...@@ -1001,14 +1001,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -1001,14 +1001,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna"; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
else{{ else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, true>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, true>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna_pddv"; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
}} }}
...@@ -1017,14 +1017,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -1017,14 +1017,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz"; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
else{{ else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, true>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, true>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz_pddv"; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
}} }}
...@@ -1040,14 +1040,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -1040,14 +1040,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a32"; // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a32";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
else if(t.is_v3_atomic_fp32 == false){{ else if(t.is_v3_atomic_fp32 == false){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, false, 0, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a16"; // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a16";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
}} }}
...@@ -1057,14 +1057,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -1057,14 +1057,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32"; // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
else if(t.is_v3_atomic_fp32 == false){{ else if(t.is_v3_atomic_fp32 == false){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, false, 0, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a16"; // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a16";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
}} }}
...@@ -1077,7 +1077,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -1077,7 +1077,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 1){{ else if(t.how_v3_bf16_cvt == 1){{
...@@ -1085,7 +1085,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -1085,7 +1085,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 2){{ else if(t.how_v3_bf16_cvt == 2){{
...@@ -1093,7 +1093,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -1093,7 +1093,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
}} }}
...@@ -1102,21 +1102,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -1102,21 +1102,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 0, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 1){{ else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 1, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 1, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 2){{ else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 2, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 2, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
}} }}
...@@ -1128,7 +1128,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -1128,7 +1128,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 1){{ else if(t.how_v3_bf16_cvt == 1){{
...@@ -1136,7 +1136,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -1136,7 +1136,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 2){{ else if(t.how_v3_bf16_cvt == 2){{
...@@ -1144,7 +1144,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -1144,7 +1144,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r; return r;
}} }}
}} }}
...@@ -1153,21 +1153,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -1153,21 +1153,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 0, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 0, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne"; const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 1){{ else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 1, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 1, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 2){{ else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 2, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 2, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a); r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r; return r;
}} }}
}} }}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment