Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
cf70a2ef
Commit
cf70a2ef
authored
Feb 08, 2025
by
danyao12
Browse files
rename
parent
dc3d35a9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
60 additions
and
60 deletions
+60
-60
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+60
-60
No files found.
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
cf70a2ef
...
@@ -189,7 +189,7 @@ struct p2
...
@@ -189,7 +189,7 @@ struct p2
unsigned int _p1;
unsigned int _p1;
}};
}};
struct __attribute__((packed)) fmha_bwd_
xqa_
v3_args
struct __attribute__((packed)) fmha_bwd_v3_args
{{
{{
void* ptr_dq;
void* ptr_dq;
p2 _p0;
p2 _p0;
...
@@ -235,7 +235,7 @@ struct __attribute__((packed)) fmha_bwd_xqa_v3_args
...
@@ -235,7 +235,7 @@ struct __attribute__((packed)) fmha_bwd_xqa_v3_args
p3 _p20;
p3 _p20;
}};
}};
struct __attribute__((packed)) fmha_bwd_
xqa_v3_dp
_args
struct __attribute__((packed)) fmha_bwd_
v3_gen
_args
{{
{{
void* ptr_dq;
void* ptr_dq;
p2 _p0;
p2 _p0;
...
@@ -474,7 +474,7 @@ class fmha_bwd_v3_kernel
...
@@ -474,7 +474,7 @@ class fmha_bwd_v3_kernel
}}
}}
void
void
launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_
xqa_
v3_args args, const ck_tile::stream_config& s) const
launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_v3_args args, const ck_tile::stream_config& s) const
{{
{{
size_t arg_size = sizeof(args);
size_t arg_size = sizeof(args);
void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER,
void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER,
...
@@ -506,7 +506,7 @@ class fmha_bwd_v3_kernel
...
@@ -506,7 +506,7 @@ class fmha_bwd_v3_kernel
}}
}}
void
void
launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_
xqa_v3_dp
_args args, const ck_tile::stream_config& s) const
launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_
v3_gen
_args args, const ck_tile::stream_config& s) const
{{
{{
size_t arg_size = sizeof(args);
size_t arg_size = sizeof(args);
void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER,
void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER,
...
@@ -555,11 +555,11 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
...
@@ -555,11 +555,11 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
}}
}}
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_>
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_>
float fmha_bwd_v3_
xqa_
(const ck_tile::stream_config& s, fmha_bwd_args a)
float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
{{
if(s.log_level_ > 0)
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << std::flush;
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << std::flush;
fmha_bwd_
xqa_
v3_args args;
fmha_bwd_v3_args args;
args.ptr_dq = a.dq_ptr;
args.ptr_dq = a.dq_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dv = a.dv_ptr;
args.ptr_dv = a.dv_ptr;
...
@@ -598,11 +598,11 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a)
...
@@ -598,11 +598,11 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a)
}}
}}
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_>
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_>
float fmha_bwd_v3_
hdp_xqa
_(const ck_tile::stream_config& s, fmha_bwd_args a)
float fmha_bwd_v3_
gen
_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
{{
if(s.log_level_ > 0)
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << std::flush;
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << std::flush;
fmha_bwd_
xqa_v3_dp
_args args;
fmha_bwd_
v3_gen
_args args;
args.ptr_dq = a.dq_ptr;
args.ptr_dq = a.dq_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dv = a.dv_ptr;
args.ptr_dv = a.dv_ptr;
...
@@ -642,11 +642,11 @@ float fmha_bwd_v3_hdp_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a)
...
@@ -642,11 +642,11 @@ float fmha_bwd_v3_hdp_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a)
}}
}}
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_, typename convert_dq_trait_>
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_, typename convert_dq_trait_>
float fmha_bwd_v3_
xqa_
(const ck_tile::stream_config& s, fmha_bwd_args a)
float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
{{
if(s.log_level_ > 0)
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
fmha_bwd_
xqa_
v3_args args;
fmha_bwd_v3_args args;
args.ptr_dq = a.dq_acc_ptr;
args.ptr_dq = a.dq_acc_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dv = a.dv_ptr;
args.ptr_dv = a.dv_ptr;
...
@@ -686,11 +686,11 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a)
...
@@ -686,11 +686,11 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a)
}}
}}
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_, typename convert_dq_trait_>
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_, typename convert_dq_trait_>
float fmha_bwd_v3_
hdp_xqa
_(const ck_tile::stream_config& s, fmha_bwd_args a)
float fmha_bwd_v3_
gen
_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
{{
if(s.log_level_ > 0)
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
fmha_bwd_
xqa_v3_dp
_args args;
fmha_bwd_
v3_gen
_args args;
args.ptr_dq = a.dq_acc_ptr;
args.ptr_dq = a.dq_acc_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dv = a.dv_ptr;
args.ptr_dv = a.dv_ptr;
...
@@ -747,7 +747,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -747,7 +747,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_a32";
// const std::string bwd_v3_name = "bwd_v3_fp16_a32";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
return r;
}}
}}
else{{
else{{
...
@@ -755,7 +755,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -755,7 +755,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_a32_pddv";
// const std::string bwd_v3_name = "bwd_v3_fp16_a32_pddv";
r = fmha_bwd_v3_
hdp_xqa
_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_
gen
_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
return r;
}}
}}
}}
}}
...
@@ -764,14 +764,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -764,14 +764,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_a16";
// const std::string bwd_v3_name = "bwd_v3_fp16_a16";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
return r;
}}
}}
else{{
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, true>;
// const std::string bwd_v3_name = "bwd_v3_fp16_a16_pddv";
// const std::string bwd_v3_name = "bwd_v3_fp16_a16_pddv";
r = fmha_bwd_v3_
hdp_xqa
_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_
gen
_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
return r;
}}
}}
}}
}}
...
@@ -783,7 +783,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -783,7 +783,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32";
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
return r;
}}
}}
else{{
else{{
...
@@ -791,7 +791,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -791,7 +791,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32_pddv";
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32_pddv";
r = fmha_bwd_v3_
hdp_xqa
_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_
gen
_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
return r;
}}
}}
}}
}}
...
@@ -800,14 +800,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -800,14 +800,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16";
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
return r;
}}
}}
else{{
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, true>;
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16_pddv";
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16_pddv";
r = fmha_bwd_v3_
hdp_xqa
_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_
gen
_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
return r;
}}
}}
}}
}}
...
@@ -822,7 +822,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -822,7 +822,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne";
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
return r;
}}
}}
else{{
else{{
...
@@ -830,7 +830,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -830,7 +830,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne_pddv";
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne_pddv";
r = fmha_bwd_v3_
hdp_xqa
_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_
gen
_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
return r;
}}
}}
}}
}}
...
@@ -840,7 +840,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -840,7 +840,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna";
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
return r;
}}
}}
else{{
else{{
...
@@ -848,7 +848,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -848,7 +848,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna_pddv";
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna_pddv";
r = fmha_bwd_v3_
hdp_xqa
_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_
gen
_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
return r;
}}
}}
}}
}}
...
@@ -858,7 +858,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -858,7 +858,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz";
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
return r;
}}
}}
else{{
else{{
...
@@ -866,7 +866,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -866,7 +866,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz_pddv";
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz_pddv";
r = fmha_bwd_v3_
hdp_xqa
_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_
gen
_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
return r;
}}
}}
}}
}}
...
@@ -877,14 +877,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -877,14 +877,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne";
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
return r;
}}
}}
else{{
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, true>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne_pddv";
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne_pddv";
r = fmha_bwd_v3_
hdp_xqa
_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_
gen
_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
return r;
}}
}}
}}
}}
...
@@ -893,14 +893,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -893,14 +893,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna";
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
return r;
}}
}}
else{{
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, true>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna_pddv";
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna_pddv";
r = fmha_bwd_v3_
hdp_xqa
_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_
gen
_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
return r;
}}
}}
}}
}}
...
@@ -909,14 +909,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -909,14 +909,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz";
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
return r;
}}
}}
else{{
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, true>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz_pddv";
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz_pddv";
r = fmha_bwd_v3_
hdp_xqa
_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_
gen
_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
return r;
}}
}}
}}
}}
...
@@ -930,7 +930,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -930,7 +930,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne";
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
return r;
}}
}}
else{{
else{{
...
@@ -938,7 +938,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -938,7 +938,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_pddv";
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_pddv";
r = fmha_bwd_v3_
hdp_xqa
_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_
gen
_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
return r;
}}
}}
}}
}}
...
@@ -948,7 +948,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -948,7 +948,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna";
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
return r;
}}
}}
else{{
else{{
...
@@ -956,7 +956,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -956,7 +956,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_pddv";
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_pddv";
r = fmha_bwd_v3_
hdp_xqa
_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_
gen
_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
return r;
}}
}}
}}
}}
...
@@ -966,7 +966,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -966,7 +966,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz";
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
return r;
}}
}}
else{{
else{{
...
@@ -974,7 +974,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -974,7 +974,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_pddv";
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_pddv";
r = fmha_bwd_v3_
hdp_xqa
_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_
gen
_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
return r;
}}
}}
}}
}}
...
@@ -985,14 +985,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -985,14 +985,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne";
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
return r;
}}
}}
else{{
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, true>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne_pddv";
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne_pddv";
r = fmha_bwd_v3_
hdp_xqa
_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_
gen
_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
return r;
}}
}}
}}
}}
...
@@ -1001,14 +1001,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -1001,14 +1001,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna";
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
return r;
}}
}}
else{{
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, true>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna_pddv";
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna_pddv";
r = fmha_bwd_v3_
hdp_xqa
_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_
gen
_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
return r;
}}
}}
}}
}}
...
@@ -1017,14 +1017,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -1017,14 +1017,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz";
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
return r;
}}
}}
else{{
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, true>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz_pddv";
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz_pddv";
r = fmha_bwd_v3_
hdp_xqa
_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_
gen
_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
return r;
}}
}}
}}
}}
...
@@ -1040,14 +1040,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -1040,14 +1040,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a32";
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a32";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
return r;
}}
}}
else if(t.is_v3_atomic_fp32 == false){{
else if(t.is_v3_atomic_fp32 == false){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, false, 0, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a16";
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a16";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
return r;
}}
}}
}}
}}
...
@@ -1057,14 +1057,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -1057,14 +1057,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32";
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
return r;
}}
}}
else if(t.is_v3_atomic_fp32 == false){{
else if(t.is_v3_atomic_fp32 == false){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, false, 0, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a16";
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a16";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
return r;
}}
}}
}}
}}
...
@@ -1077,7 +1077,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -1077,7 +1077,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne";
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
return r;
}}
}}
else if(t.how_v3_bf16_cvt == 1){{
else if(t.how_v3_bf16_cvt == 1){{
...
@@ -1085,7 +1085,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -1085,7 +1085,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna";
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
return r;
}}
}}
else if(t.how_v3_bf16_cvt == 2){{
else if(t.how_v3_bf16_cvt == 2){{
...
@@ -1093,7 +1093,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -1093,7 +1093,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz";
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
return r;
}}
}}
}}
}}
...
@@ -1102,21 +1102,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -1102,21 +1102,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 0, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne";
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
return r;
}}
}}
else if(t.how_v3_bf16_cvt == 1){{
else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 1, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 1, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna";
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
return r;
}}
}}
else if(t.how_v3_bf16_cvt == 2){{
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 2, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 2, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz";
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
return r;
}}
}}
}}
}}
...
@@ -1128,7 +1128,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -1128,7 +1128,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne";
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
return r;
}}
}}
else if(t.how_v3_bf16_cvt == 1){{
else if(t.how_v3_bf16_cvt == 1){{
...
@@ -1136,7 +1136,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -1136,7 +1136,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna";
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
return r;
}}
}}
else if(t.how_v3_bf16_cvt == 2){{
else if(t.how_v3_bf16_cvt == 2){{
...
@@ -1144,7 +1144,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -1144,7 +1144,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz";
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
return r;
}}
}}
}}
}}
...
@@ -1153,21 +1153,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -1153,21 +1153,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 0, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 0, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne";
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
return r;
}}
}}
else if(t.how_v3_bf16_cvt == 1){{
else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 1, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 1, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna";
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
return r;
}}
}}
else if(t.how_v3_bf16_cvt == 2){{
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 2, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 2, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz";
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz";
r = fmha_bwd_v3_
xqa_
<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
return r;
}}
}}
}}
}}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment