Commit cf70a2ef authored by danyao12's avatar danyao12
Browse files

rename

parent dc3d35a9
......@@ -189,7 +189,7 @@ struct p2
unsigned int _p1;
}};
struct __attribute__((packed)) fmha_bwd_xqa_v3_args
struct __attribute__((packed)) fmha_bwd_v3_args
{{
void* ptr_dq;
p2 _p0;
......@@ -235,7 +235,7 @@ struct __attribute__((packed)) fmha_bwd_xqa_v3_args
p3 _p20;
}};
struct __attribute__((packed)) fmha_bwd_xqa_v3_dp_args
struct __attribute__((packed)) fmha_bwd_v3_gen_args
{{
void* ptr_dq;
p2 _p0;
......@@ -474,7 +474,7 @@ class fmha_bwd_v3_kernel
}}
void
launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_xqa_v3_args args, const ck_tile::stream_config& s) const
launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_v3_args args, const ck_tile::stream_config& s) const
{{
size_t arg_size = sizeof(args);
void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER,
......@@ -506,7 +506,7 @@ class fmha_bwd_v3_kernel
}}
void
launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_xqa_v3_dp_args args, const ck_tile::stream_config& s) const
launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_v3_gen_args args, const ck_tile::stream_config& s) const
{{
size_t arg_size = sizeof(args);
void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER,
......@@ -555,11 +555,11 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
}}
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_>
float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a)
float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << std::flush;
fmha_bwd_xqa_v3_args args;
fmha_bwd_v3_args args;
args.ptr_dq = a.dq_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dv = a.dv_ptr;
......@@ -598,11 +598,11 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a)
}}
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_>
float fmha_bwd_v3_hdp_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a)
float fmha_bwd_v3_gen_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << std::flush;
fmha_bwd_xqa_v3_dp_args args;
fmha_bwd_v3_gen_args args;
args.ptr_dq = a.dq_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dv = a.dv_ptr;
......@@ -642,11 +642,11 @@ float fmha_bwd_v3_hdp_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a)
}}
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_, typename convert_dq_trait_>
float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a)
float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
fmha_bwd_xqa_v3_args args;
fmha_bwd_v3_args args;
args.ptr_dq = a.dq_acc_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dv = a.dv_ptr;
......@@ -686,11 +686,11 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a)
}}
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_, typename convert_dq_trait_>
float fmha_bwd_v3_hdp_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a)
float fmha_bwd_v3_gen_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
fmha_bwd_xqa_v3_dp_args args;
fmha_bwd_v3_gen_args args;
args.ptr_dq = a.dq_acc_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dv = a.dv_ptr;
......@@ -747,7 +747,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_a32";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
......@@ -755,7 +755,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_a32_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
......@@ -764,14 +764,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_a16";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, true>;
// const std::string bwd_v3_name = "bwd_v3_fp16_a16_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
}}
......@@ -783,7 +783,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
......@@ -791,7 +791,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
......@@ -800,14 +800,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, true>;
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
}}
......@@ -822,7 +822,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
......@@ -830,7 +830,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
......@@ -840,7 +840,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
......@@ -848,7 +848,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
......@@ -858,7 +858,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
......@@ -866,7 +866,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
......@@ -877,14 +877,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, true>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
}}
......@@ -893,14 +893,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, true>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
}}
......@@ -909,14 +909,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, true>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
}}
......@@ -930,7 +930,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
......@@ -938,7 +938,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
......@@ -948,7 +948,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
......@@ -956,7 +956,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
......@@ -966,7 +966,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
......@@ -974,7 +974,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
......@@ -985,14 +985,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, true>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
}}
......@@ -1001,14 +1001,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, true>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
}}
......@@ -1017,14 +1017,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, true>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz_pddv";
r = fmha_bwd_v3_hdp_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
}}
......@@ -1040,14 +1040,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a32";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else if(t.is_v3_atomic_fp32 == false){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a16";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
}}
......@@ -1057,14 +1057,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else if(t.is_v3_atomic_fp32 == false){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a16";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
}}
......@@ -1077,7 +1077,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else if(t.how_v3_bf16_cvt == 1){{
......@@ -1085,7 +1085,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
......@@ -1093,7 +1093,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
......@@ -1102,21 +1102,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 1, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 2, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
}}
......@@ -1128,7 +1128,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else if(t.how_v3_bf16_cvt == 1){{
......@@ -1136,7 +1136,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
......@@ -1144,7 +1144,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
......@@ -1153,21 +1153,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 0, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 1, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 2, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz";
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
}}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment