"include/vscode:/vscode.git/clone" did not exist on "d1006d46bbda5b74224e75197b1496a6db5bdc48"
Commit 66cbdd6c authored by danyao12's avatar danyao12
Browse files

fav3 bwd hd64 bf16 a16 verification passed

parent 55d982c3
...@@ -67,7 +67,7 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd") ...@@ -67,7 +67,7 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
# to be included in "make all/install/check" # to be included in "make all/install/check"
message("adding example ${EXAMPLE_FMHA_BWD}") message("adding example ${EXAMPLE_FMHA_BWD}")
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL hsaco/bwd_bf16_a16_rtna.cpp hsaco/bwd_bf16_a16_rtne.cpp hsaco/bwd_bf16_a16_rtz.cpp hsaco/bwd_bf16_a32_rtna.cpp hsaco/bwd_bf16_a32_rtne.cpp hsaco/bwd_bf16_a32_rtz.cpp hsaco/bwd_bf16_causal_a16_rtna.cpp hsaco/bwd_bf16_causal_a16_rtne.cpp hsaco/bwd_bf16_causal_a16_rtz.cpp hsaco/bwd_bf16_causal_a32_rtna.cpp hsaco/bwd_bf16_causal_a32_rtne.cpp hsaco/bwd_bf16_causal_a32_rtz.cpp hsaco/bwd_bf16_spec_a32.cpp hsaco/bwd_bf16_spec_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a16.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_spec_a32.cpp hsaco/bwd_fp16_spec_causal_a32.cpp fmha_bwd.cpp) add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL hsaco/bwd_hd64_bf16_a16_rtna.cpp hsaco/bwd_bf16_a16_rtna.cpp hsaco/bwd_bf16_a16_rtne.cpp hsaco/bwd_bf16_a16_rtz.cpp hsaco/bwd_bf16_a32_rtna.cpp hsaco/bwd_bf16_a32_rtne.cpp hsaco/bwd_bf16_a32_rtz.cpp hsaco/bwd_bf16_causal_a16_rtna.cpp hsaco/bwd_bf16_causal_a16_rtne.cpp hsaco/bwd_bf16_causal_a16_rtz.cpp hsaco/bwd_bf16_causal_a32_rtna.cpp hsaco/bwd_bf16_causal_a32_rtne.cpp hsaco/bwd_bf16_causal_a32_rtz.cpp hsaco/bwd_bf16_spec_a32.cpp hsaco/bwd_bf16_spec_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a16.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_spec_a32.cpp hsaco/bwd_fp16_spec_causal_a32.cpp fmha_bwd.cpp)
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS}) target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS})
......
...@@ -554,184 +554,200 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -554,184 +554,200 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
if (t.uses_bwd_v3 == true){{ if (t.uses_bwd_v3 == true){{
if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) &&
(a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 128 == 0) && (a.hdim_q == 128) && (a.hdim_v == 128) && (t.is_deterministic == false) && (a.seqlen_q == a.seqlen_k) && (t.is_deterministic == false) && (a.stride_q == a.stride_o /*i_perm == o_perm*/)) {{
(a.stride_q == a.stride_o /*i_perm == o_perm*/)) {{ if((a.hdim_q == 128) && (a.hdim_v == 128) && (a.seqlen_k % 128 == 0)){{
if(t.data_type.compare("fp16") == 0){{ if(t.data_type.compare("fp16") == 0){{
if(t.mask_type == mask_enum::no_mask){{ if(t.mask_type == mask_enum::no_mask){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{ if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{ if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_spec_a32"; const std::string bwd_v3_name = "bwd_v3_fp16_spec_a32";
bool io_perm = a.nhead_stride_q > a.stride_q; bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_spec_a32, bwd_v3_name, io_perm, 32, 128); r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_spec_a32, bwd_v3_name, io_perm, 32, 128);
return r; return r;
}}
else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_a32, bwd_v3_name, io_perm, 16, 192);
return r;
}}
}} }}
else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{ else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; const std::string bwd_v3_name = "bwd_v3_fp16_a16";
const std::string bwd_v3_name = "bwd_v3_fp16_a32";
bool io_perm = a.nhead_stride_q > a.stride_q; bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_a32, bwd_v3_name, io_perm, 16, 192); r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_fp16_a16, bwd_v3_name, io_perm, 16, 192);
return r; return r;
}} }}
}} }}
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{ else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
const std::string bwd_v3_name = "bwd_v3_fp16_a16"; if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{
bool io_perm = a.nhead_stride_q > a.stride_q; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_fp16_a16, bwd_v3_name, io_perm, 16, 192); using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
return r; const std::string bwd_v3_name = "bwd_v3_fp16_spec_causal_a32";
}} bool io_perm = a.nhead_stride_q > a.stride_q;
}} r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_spec_causal_a32, bwd_v3_name, io_perm, 32, 128);
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{ return r;
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{ }}
if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{ else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_spec_causal_a32"; const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q; bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_spec_causal_a32, bwd_v3_name, io_perm, 32, 128); r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_causal_a32, bwd_v3_name, io_perm, 16, 192);
return r; return r;
}}
}} }}
else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{ else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16";
const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q; bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_causal_a32, bwd_v3_name, io_perm, 16, 192); r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_fp16_causal_a16, bwd_v3_name, io_perm, 16, 192);
return r; return r;
}} }}
}} }}
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_fp16_causal_a16, bwd_v3_name, io_perm, 16, 192);
return r;
}}
}} }}
}} else if(t.data_type.compare("bf16") == 0){{
else if(t.data_type.compare("bf16") == 0){{ if(t.mask_type == mask_enum::no_mask){{
if(t.mask_type == mask_enum::no_mask){{ if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{ if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{
if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; const std::string bwd_v3_name = "bwd_v3_bf16_spec_a32";
const std::string bwd_v3_name = "bwd_v3_bf16_spec_a32"; bool io_perm = a.nhead_stride_q > a.stride_q;
bool io_perm = a.nhead_stride_q > a.stride_q; r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_spec_a32, bwd_v3_name, io_perm, 32, 128);
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_spec_a32, bwd_v3_name, io_perm, 32, 128); return r;
return r; }}
else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32_rtne, bwd_v3_name, io_perm, 16, 192);
return r;
}}
else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32_rtna, bwd_v3_name, io_perm, 16, 192);
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32_rtz, bwd_v3_name, io_perm, 16, 192);
return r;
}}
}}
}} }}
else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{ else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 0){{ if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne";
const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q; bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32_rtne, bwd_v3_name, io_perm, 16, 192); r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16_rtne, bwd_v3_name, io_perm, 16, 192);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 1){{ else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna";
const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q; bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32_rtna, bwd_v3_name, io_perm, 16, 192); r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16_rtna, bwd_v3_name, io_perm, 16, 192);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 2){{ else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz";
const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q; bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32_rtz, bwd_v3_name, io_perm, 16, 192); r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16_rtz, bwd_v3_name, io_perm, 16, 192);
return r; return r;
}} }}
}} }}
}} }}
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{ else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if(t.how_v3_bf16_cvt == 0){{ if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{
const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne"; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
bool io_perm = a.nhead_stride_q > a.stride_q; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16_rtne, bwd_v3_name, io_perm, 16, 192); const std::string bwd_v3_name = "bwd_v3_bf16_spec_causal_a32";
return r; bool io_perm = a.nhead_stride_q > a.stride_q;
}} r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_spec_causal_a32, bwd_v3_name, io_perm, 32, 128);
else if(t.how_v3_bf16_cvt == 1){{ return r;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; }}
const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna"; else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{
bool io_perm = a.nhead_stride_q > a.stride_q; if(t.how_v3_bf16_cvt == 0){{
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16_rtna, bwd_v3_name, io_perm, 16, 192); using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
return r; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
}} const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne";
else if(t.how_v3_bf16_cvt == 2){{ bool io_perm = a.nhead_stride_q > a.stride_q;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32_rtne, bwd_v3_name, io_perm, 16, 192);
const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz"; return r;
bool io_perm = a.nhead_stride_q > a.stride_q; }}
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16_rtz, bwd_v3_name, io_perm, 16, 192); else if(t.how_v3_bf16_cvt == 1){{
return r; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
}} using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
}} const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna";
}} bool io_perm = a.nhead_stride_q > a.stride_q;
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{ r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32_rtna, bwd_v3_name, io_perm, 16, 192);
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{ return r;
if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{ }}
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; else if(t.how_v3_bf16_cvt == 2){{
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_spec_causal_a32"; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
bool io_perm = a.nhead_stride_q > a.stride_q; const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz";
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_spec_causal_a32, bwd_v3_name, io_perm, 32, 128); bool io_perm = a.nhead_stride_q > a.stride_q;
return r; r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32_rtz, bwd_v3_name, io_perm, 16, 192);
return r;
}}
}}
}} }}
else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{ else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 0){{ if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne";
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q; bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32_rtne, bwd_v3_name, io_perm, 16, 192); r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16_rtne, bwd_v3_name, io_perm, 16, 192);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 1){{ else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna";
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q; bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32_rtna, bwd_v3_name, io_perm, 16, 192); r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16_rtna, bwd_v3_name, io_perm, 16, 192);
return r; return r;
}} }}
else if(t.how_v3_bf16_cvt == 2){{ else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz";
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q; bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32_rtz, bwd_v3_name, io_perm, 16, 192); r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16_rtz, bwd_v3_name, io_perm, 16, 192);
return r; return r;
}} }}
}} }}
}} }}
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{ }}
if(t.how_v3_bf16_cvt == 0){{ }}
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; else if((a.hdim_q == 64) && (a.hdim_v == 64) && (a.seqlen_k % 64 == 0)){{
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne"; if(t.data_type.compare("bf16") == 0){{
bool io_perm = a.nhead_stride_q > a.stride_q; if(t.mask_type == mask_enum::no_mask){{
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16_rtne, bwd_v3_name, io_perm, 16, 192); if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
return r; if(t.how_v3_bf16_cvt == 1){{
}} using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
else if(t.how_v3_bf16_cvt == 1){{ const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna";
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; bool io_perm = a.nhead_stride_q > a.stride_q;
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna"; r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_hd64_bf16_a16_rtna, bwd_v3_name, io_perm, 32, 192);
bool io_perm = a.nhead_stride_q > a.stride_q; return r;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16_rtna, bwd_v3_name, io_perm, 16, 192); }}
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16_rtz, bwd_v3_name, io_perm, 16, 192);
return r;
}} }}
}} }}
}} }}
......
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -23,3 +23,4 @@ extern unsigned char bwd_fp16_causal_a16[]; ...@@ -23,3 +23,4 @@ extern unsigned char bwd_fp16_causal_a16[];
extern unsigned char bwd_fp16_causal_a32[]; extern unsigned char bwd_fp16_causal_a32[];
extern unsigned char bwd_fp16_spec_a32[]; extern unsigned char bwd_fp16_spec_a32[];
extern unsigned char bwd_fp16_spec_causal_a32[]; extern unsigned char bwd_fp16_spec_causal_a32[];
extern unsigned char bwd_hd64_bf16_a16_rtna[];
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment