Commit 66cbdd6c authored by danyao12's avatar danyao12
Browse files

fav3 bwd hd64 bf16 a16 verification passed

parent 55d982c3
...@@ -67,7 +67,7 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd") ...@@ -67,7 +67,7 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
# to be included in "make all/install/check" # to be included in "make all/install/check"
message("adding example ${EXAMPLE_FMHA_BWD}") message("adding example ${EXAMPLE_FMHA_BWD}")
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL hsaco/bwd_bf16_a16_rtna.cpp hsaco/bwd_bf16_a16_rtne.cpp hsaco/bwd_bf16_a16_rtz.cpp hsaco/bwd_bf16_a32_rtna.cpp hsaco/bwd_bf16_a32_rtne.cpp hsaco/bwd_bf16_a32_rtz.cpp hsaco/bwd_bf16_causal_a16_rtna.cpp hsaco/bwd_bf16_causal_a16_rtne.cpp hsaco/bwd_bf16_causal_a16_rtz.cpp hsaco/bwd_bf16_causal_a32_rtna.cpp hsaco/bwd_bf16_causal_a32_rtne.cpp hsaco/bwd_bf16_causal_a32_rtz.cpp hsaco/bwd_bf16_spec_a32.cpp hsaco/bwd_bf16_spec_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a16.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_spec_a32.cpp hsaco/bwd_fp16_spec_causal_a32.cpp fmha_bwd.cpp) add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL hsaco/bwd_hd64_bf16_a16_rtna.cpp hsaco/bwd_bf16_a16_rtna.cpp hsaco/bwd_bf16_a16_rtne.cpp hsaco/bwd_bf16_a16_rtz.cpp hsaco/bwd_bf16_a32_rtna.cpp hsaco/bwd_bf16_a32_rtne.cpp hsaco/bwd_bf16_a32_rtz.cpp hsaco/bwd_bf16_causal_a16_rtna.cpp hsaco/bwd_bf16_causal_a16_rtne.cpp hsaco/bwd_bf16_causal_a16_rtz.cpp hsaco/bwd_bf16_causal_a32_rtna.cpp hsaco/bwd_bf16_causal_a32_rtne.cpp hsaco/bwd_bf16_causal_a32_rtz.cpp hsaco/bwd_bf16_spec_a32.cpp hsaco/bwd_bf16_spec_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a16.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_spec_a32.cpp hsaco/bwd_fp16_spec_causal_a32.cpp fmha_bwd.cpp)
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS}) target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS})
......
...@@ -554,8 +554,8 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -554,8 +554,8 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
if (t.uses_bwd_v3 == true){{ if (t.uses_bwd_v3 == true){{
if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) &&
(a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 128 == 0) && (a.hdim_q == 128) && (a.hdim_v == 128) && (t.is_deterministic == false) && (a.seqlen_q == a.seqlen_k) && (t.is_deterministic == false) && (a.stride_q == a.stride_o /*i_perm == o_perm*/)) {{
(a.stride_q == a.stride_o /*i_perm == o_perm*/)) {{ if((a.hdim_q == 128) && (a.hdim_v == 128) && (a.seqlen_k % 128 == 0)){{
if(t.data_type.compare("fp16") == 0){{ if(t.data_type.compare("fp16") == 0){{
if(t.mask_type == mask_enum::no_mask){{ if(t.mask_type == mask_enum::no_mask){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{ if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
...@@ -737,6 +737,22 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -737,6 +737,22 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
}} }}
}} }}
}} }}
else if((a.hdim_q == 64) && (a.hdim_v == 64) && (a.seqlen_k % 64 == 0)){{
if(t.data_type.compare("bf16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_hd64_bf16_a16_rtna, bwd_v3_name, io_perm, 32, 192);
return r;
}}
}}
}}
}}
}}
}}
}} }}
{F_dispatch} {F_dispatch}
......
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -23,3 +23,4 @@ extern unsigned char bwd_fp16_causal_a16[]; ...@@ -23,3 +23,4 @@ extern unsigned char bwd_fp16_causal_a16[];
extern unsigned char bwd_fp16_causal_a32[]; extern unsigned char bwd_fp16_causal_a32[];
extern unsigned char bwd_fp16_spec_a32[]; extern unsigned char bwd_fp16_spec_a32[];
extern unsigned char bwd_fp16_spec_causal_a32[]; extern unsigned char bwd_fp16_spec_causal_a32[];
extern unsigned char bwd_hd64_bf16_a16_rtna[];
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment