"...composable_kernel_rocm.git" did not exist on "3af8c81a72b5b5a0155eb0e95c4f0aba1b375cca"
Commit 92d458d5 authored by danyao12
Browse files

Merge branch 'ck_tile/fa_bwd_v3' into ck_tile/fa_bwd_v3_semi

parents 1863752c 2defe2f6
...@@ -67,7 +67,7 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd") ...@@ -67,7 +67,7 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
# to be included in "make all/install/check" # to be included in "make all/install/check"
message("adding example ${EXAMPLE_FMHA_BWD}") message("adding example ${EXAMPLE_FMHA_BWD}")
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL hsaco/bwd_hd64_bf16_a16_rtna.cpp hsaco/bwd_hd64_bf16_a16_rtne.cpp hsaco/bwd_hd64_bf16_a16_rtz.cpp hsaco/bwd_hd64_bf16_causal_a16_rtna.cpp hsaco/bwd_hd64_bf16_causal_a16_rtne.cpp hsaco/bwd_hd64_bf16_causal_a16_rtz.cpp hsaco/bwd_bf16_a16_rtna.cpp hsaco/bwd_bf16_a16_rtne.cpp hsaco/bwd_bf16_a16_rtz.cpp hsaco/bwd_bf16_a32_rtna.cpp hsaco/bwd_bf16_a32_rtne.cpp hsaco/bwd_bf16_a32_rtz.cpp hsaco/bwd_bf16_causal_a16_rtna.cpp hsaco/bwd_bf16_causal_a16_rtne.cpp hsaco/bwd_bf16_causal_a16_rtz.cpp hsaco/bwd_bf16_causal_a32_rtna.cpp hsaco/bwd_bf16_causal_a32_rtne.cpp hsaco/bwd_bf16_causal_a32_rtz.cpp hsaco/bwd_bf16_spec_a32.cpp hsaco/bwd_bf16_spec_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a16.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_spec_a32.cpp hsaco/bwd_fp16_spec_causal_a32.cpp fmha_bwd.cpp) add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL hsaco/bwd_hd64_bf16_a16_rtna.cpp hsaco/bwd_hd64_bf16_a16_rtne.cpp hsaco/bwd_hd64_bf16_a16_rtz.cpp hsaco/bwd_hd64_bf16_a32_rtna.cpp hsaco/bwd_hd64_bf16_a32_rtne.cpp hsaco/bwd_hd64_bf16_a32_rtz.cpp hsaco/bwd_hd64_bf16_causal_a16_rtna.cpp hsaco/bwd_hd64_bf16_causal_a16_rtne.cpp hsaco/bwd_hd64_bf16_causal_a16_rtz.cpp hsaco/bwd_hd64_bf16_causal_a32_rtna.cpp hsaco/bwd_hd64_bf16_causal_a32_rtne.cpp hsaco/bwd_hd64_bf16_causal_a32_rtz.cpp hsaco/bwd_bf16_a16_rtna.cpp hsaco/bwd_bf16_a16_rtne.cpp hsaco/bwd_bf16_a16_rtz.cpp hsaco/bwd_bf16_a32_rtna.cpp hsaco/bwd_bf16_a32_rtne.cpp hsaco/bwd_bf16_a32_rtz.cpp hsaco/bwd_bf16_causal_a16_rtna.cpp hsaco/bwd_bf16_causal_a16_rtne.cpp hsaco/bwd_bf16_causal_a16_rtz.cpp hsaco/bwd_bf16_causal_a32_rtna.cpp hsaco/bwd_bf16_causal_a32_rtne.cpp hsaco/bwd_bf16_causal_a32_rtz.cpp hsaco/bwd_bf16_spec_a32.cpp hsaco/bwd_bf16_spec_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a16.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_spec_a32.cpp hsaco/bwd_fp16_spec_causal_a32.cpp fmha_bwd.cpp)
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS}) target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS})
......
...@@ -740,7 +740,33 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -740,7 +740,33 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
else if((a.hdim_q == 64) && (a.hdim_v == 64) && (a.seqlen_k % 64 == 0)){{ else if((a.hdim_q == 64) && (a.hdim_v == 64) && (a.seqlen_k % 64 == 0)){{
if(t.data_type.compare("bf16") == 0){{ if(t.data_type.compare("bf16") == 0){{
if(t.mask_type == mask_enum::no_mask){{ if(t.mask_type == mask_enum::no_mask){{
if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{ if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_hd64_bf16_a32_rtne, bwd_v3_name, io_perm, 32, 192);
return r;
}}
else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_hd64_bf16_a32_rtna, bwd_v3_name, io_perm, 32, 192);
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_hd64_bf16_a32_rtz, bwd_v3_name, io_perm, 32, 192);
return r;
}}
}}
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 0){{ if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne"; const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne";
...@@ -765,7 +791,33 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -765,7 +791,33 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
}} }}
}} }}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{ else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{ if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_hd64_bf16_causal_a32_rtne, bwd_v3_name, io_perm, 32, 192);
return r;
}}
else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_hd64_bf16_causal_a32_rtna, bwd_v3_name, io_perm, 32, 192);
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_hd64_bf16_causal_a32_rtz, bwd_v3_name, io_perm, 32, 192);
return r;
}}
}}
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 0){{ if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne"; const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne";
......
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This diff is collapsed.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -26,6 +26,12 @@ extern unsigned char bwd_fp16_spec_causal_a32[]; ...@@ -26,6 +26,12 @@ extern unsigned char bwd_fp16_spec_causal_a32[];
extern unsigned char bwd_hd64_bf16_a16_rtna[]; extern unsigned char bwd_hd64_bf16_a16_rtna[];
extern unsigned char bwd_hd64_bf16_a16_rtne[]; extern unsigned char bwd_hd64_bf16_a16_rtne[];
extern unsigned char bwd_hd64_bf16_a16_rtz[]; extern unsigned char bwd_hd64_bf16_a16_rtz[];
extern unsigned char bwd_hd64_bf16_a32_rtna[];
extern unsigned char bwd_hd64_bf16_a32_rtne[];
extern unsigned char bwd_hd64_bf16_a32_rtz[];
extern unsigned char bwd_hd64_bf16_causal_a16_rtna[]; extern unsigned char bwd_hd64_bf16_causal_a16_rtna[];
extern unsigned char bwd_hd64_bf16_causal_a16_rtne[]; extern unsigned char bwd_hd64_bf16_causal_a16_rtne[];
extern unsigned char bwd_hd64_bf16_causal_a16_rtz[]; extern unsigned char bwd_hd64_bf16_causal_a16_rtz[];
extern unsigned char bwd_hd64_bf16_causal_a32_rtna[];
extern unsigned char bwd_hd64_bf16_causal_a32_rtne[];
extern unsigned char bwd_hd64_bf16_causal_a32_rtz[];
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment