Commit 0c126ffc authored by danyao12
Browse files

qdo/kv strides split

parent 43d81903
...@@ -495,7 +495,7 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) ...@@ -495,7 +495,7 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
}} }}
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_> template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_>
float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, bool io_perm) float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{ {{
if(s.log_level_ > 0) if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << std::flush; std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << std::flush;
...@@ -513,35 +513,16 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, bool io ...@@ -513,35 +513,16 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, bool io
args.log2e = ck_tile::log2e_v<float>; args.log2e = ck_tile::log2e_v<float>;
args.seq_len = a.seqlen_q; args.seq_len = a.seqlen_q;
int stride_tg = FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv * a.hdim_q * 2; args.Ts = FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv * a.stride_k * 2;
int stride_head = a.seqlen_q * a.hdim_q * 2; args.Hs = a.nhead_stride_q * 2;
int stride_batch = a.nhead_q * a.seqlen_q * a.hdim_q * 2; args.BAs = a.batch_stride_q * 2;
int stride_seqlen = a.hdim_q * 2; args.Seqs = a.stride_q * 2;
int stride_head_kv = a.seqlen_q * a.hdim_q * 2;
int stride_batch_kv = a.nhead_k * a.seqlen_q * a.hdim_q * 2;
int stride_seqlen_kv = a.hdim_q * 2;
int stride_seqlen_dkv = a.hdim_q * 2;
if(io_perm == 0) //BSHD
{{
stride_seqlen = a.nhead_q * a.hdim_q * 2;
stride_head = a.hdim_q * 2;
stride_seqlen_kv = a.nhead_k * a.hdim_q * 2;
stride_seqlen_dkv = a.nhead_q * a.hdim_q * 2;
stride_tg = FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv * stride_seqlen_kv;
stride_head_kv = a.hdim_q * 2;
}}
args.Ts = stride_tg;
args.Hs = stride_head;
args.BAs = stride_batch;
args.Seqs = stride_seqlen;
args.ratio = a.nhead_q / a.nhead_k; args.ratio = a.nhead_q / a.nhead_k;
args.Hs_kv = stride_head_kv; args.Hs_kv = a.nhead_stride_k * 2;
args.BAs_kv = stride_batch_kv; args.BAs_kv = a.batch_stride_k * 2;
args.Seqs_kv = stride_seqlen_kv; args.Seqs_kv = a.stride_k * 2;
args.Seqs_dkv = stride_seqlen_dkv; args.Seqs_dkv = a.stride_dk * 2;
auto traits = fmha_bwd_v3_traits{{a.batch, auto traits = fmha_bwd_v3_traits{{a.batch,
a.nhead_q, a.nhead_q,
a.seqlen_q, a.seqlen_q,
...@@ -607,7 +588,7 @@ float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a, bool io_per ...@@ -607,7 +588,7 @@ float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a, bool io_per
}} }}
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_, typename convert_dq_trait_> template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_, typename convert_dq_trait_>
float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, bool io_perm) float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{ {{
if(s.log_level_ > 0) if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush; std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
...@@ -625,35 +606,16 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, bool io ...@@ -625,35 +606,16 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, bool io
args.log2e = ck_tile::log2e_v<float>; args.log2e = ck_tile::log2e_v<float>;
args.seq_len = a.seqlen_q; args.seq_len = a.seqlen_q;
int stride_tg = FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv * a.hdim_q * 2; args.Ts = FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv * a.stride_k * 2;
int stride_head = a.seqlen_q * a.hdim_q * 2; args.Hs = a.nhead_stride_q * 2;
int stride_batch = a.nhead_q * a.seqlen_q * a.hdim_q * 2; args.BAs = a.batch_stride_q * 2;
int stride_seqlen = a.hdim_q * 2; args.Seqs = a.stride_q * 2;
int stride_head_kv = a.seqlen_q * a.hdim_q * 2;
int stride_batch_kv = a.nhead_k * a.seqlen_q * a.hdim_q * 2;
int stride_seqlen_kv = a.hdim_q * 2;
int stride_seqlen_dkv = a.hdim_q * 2;
if(io_perm == 0) //BSHD
{{
stride_seqlen = a.nhead_q * a.hdim_q * 2;
stride_head = a.hdim_q * 2;
stride_seqlen_kv = a.nhead_k * a.hdim_q * 2;
stride_seqlen_dkv = a.nhead_q * a.hdim_q * 2;
stride_tg = FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv * stride_seqlen_kv;
stride_head_kv = a.hdim_q * 2;
}}
args.Ts = stride_tg;
args.Hs = stride_head;
args.BAs = stride_batch;
args.Seqs = stride_seqlen;
args.ratio = a.nhead_q / a.nhead_k; args.ratio = a.nhead_q / a.nhead_k;
args.Hs_kv = stride_head_kv; args.Hs_kv = a.nhead_stride_k * 2;
args.BAs_kv = stride_batch_kv; args.BAs_kv = a.batch_stride_k * 2;
args.Seqs_kv = stride_seqlen_kv; args.Seqs_kv = a.stride_k * 2;
args.Seqs_dkv = stride_seqlen_dkv; args.Seqs_dkv = a.stride_dk * 2;
auto traits = fmha_bwd_v3_traits{{a.batch, auto traits = fmha_bwd_v3_traits{{a.batch,
a.nhead_q, a.nhead_q,
a.seqlen_q, a.seqlen_q,
...@@ -694,8 +656,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -694,8 +656,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, true, false, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, true, false, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_a32"; // const std::string bwd_v3_name = "bwd_v3_fp16_a32";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r; return r;
}} }}
}} }}
...@@ -703,8 +664,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -703,8 +664,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, false, false, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, false, false, 0>;
// const std::string bwd_v3_name = "bwd_v3_fp16_a16"; // const std::string bwd_v3_name = "bwd_v3_fp16_a16";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r; return r;
}} }}
}} }}
...@@ -724,8 +684,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -724,8 +684,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, true, false, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, true, false, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32"; // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r; return r;
}} }}
}} }}
...@@ -733,8 +692,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -733,8 +692,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, false, false, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, false, false, 0>;
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16"; // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r; return r;
}} }}
}} }}
...@@ -757,8 +715,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -757,8 +715,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne"; // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 1){{ else if(t.how_v3_bf16_cvt == 1){{
...@@ -766,8 +723,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -766,8 +723,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 1>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 1>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna"; // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 2){{ else if(t.how_v3_bf16_cvt == 2){{
...@@ -775,8 +731,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -775,8 +731,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 2>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 2>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz"; // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r; return r;
}} }}
}} }}
...@@ -786,24 +741,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -786,24 +741,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 0>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne"; // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 1){{ else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 1>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 1>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna"; // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 2){{ else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 2>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 2>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz"; // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r; return r;
}} }}
}} }}
...@@ -825,8 +777,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -825,8 +777,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne"; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 1){{ else if(t.how_v3_bf16_cvt == 1){{
...@@ -834,8 +785,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -834,8 +785,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 1>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 1>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna"; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 2){{ else if(t.how_v3_bf16_cvt == 2){{
...@@ -843,8 +793,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -843,8 +793,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 2>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 2>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz"; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r; return r;
}} }}
}} }}
...@@ -854,24 +803,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -854,24 +803,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 0>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne"; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 1){{ else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 1>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 1>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna"; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 2){{ else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 2>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 2>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz"; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r; return r;
}} }}
}} }}
...@@ -887,8 +833,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -887,8 +833,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, true, false, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, true, false, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 1){{ else if(t.how_v3_bf16_cvt == 1){{
...@@ -896,8 +841,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -896,8 +841,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, true, false, 1>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, true, false, 1>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 2){{ else if(t.how_v3_bf16_cvt == 2){{
...@@ -905,8 +849,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -905,8 +849,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, true, false, 2>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, true, false, 2>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r; return r;
}} }}
}} }}
...@@ -915,24 +858,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -915,24 +858,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, false, false, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, false, false, 0>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 1){{ else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, false, false, 1>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, false, false, 1>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 2){{ else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, false, false, 2>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, false, false, 2>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r; return r;
}} }}
}} }}
...@@ -944,8 +884,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -944,8 +884,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, true, false, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, true, false, 0>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 1){{ else if(t.how_v3_bf16_cvt == 1){{
...@@ -953,8 +892,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -953,8 +892,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, true, false, 1>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, true, false, 1>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 2){{ else if(t.how_v3_bf16_cvt == 2){{
...@@ -962,8 +900,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -962,8 +900,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, true, false, 2>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, true, false, 2>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
return r; return r;
}} }}
}} }}
...@@ -972,24 +909,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -972,24 +909,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, false, false, 0>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, false, false, 0>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne"; const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 1){{ else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, false, false, 1>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, false, false, 1>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 2){{ else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, false, false, 2>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, false, false, 2>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz"; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
return r; return r;
}} }}
}} }}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment