Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
66cbdd6c
Commit
66cbdd6c
authored
Dec 17, 2024
by
danyao12
Browse files
fav3 bwd hd64 bf16 a16 verification passed
parent
55d982c3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
2372 additions
and
130 deletions
+2372
-130
example/ck_tile/01_fmha/CMakeLists.txt
example/ck_tile/01_fmha/CMakeLists.txt
+1
-1
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+145
-129
example/ck_tile/01_fmha/hsaco/bwd_hd64_bf16_a16_rtna.cpp
example/ck_tile/01_fmha/hsaco/bwd_hd64_bf16_a16_rtna.cpp
+2225
-0
example/ck_tile/01_fmha/hsaco/fmha_hsaco.hpp
example/ck_tile/01_fmha/hsaco/fmha_hsaco.hpp
+1
-0
No files found.
example/ck_tile/01_fmha/CMakeLists.txt
View file @
66cbdd6c
...
...
@@ -67,7 +67,7 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
# to be included in "make all/install/check"
message
(
"adding example
${
EXAMPLE_FMHA_BWD
}
"
)
add_executable
(
${
EXAMPLE_FMHA_BWD
}
EXCLUDE_FROM_ALL hsaco/bwd_bf16_a16_rtna.cpp hsaco/bwd_bf16_a16_rtne.cpp hsaco/bwd_bf16_a16_rtz.cpp hsaco/bwd_bf16_a32_rtna.cpp hsaco/bwd_bf16_a32_rtne.cpp hsaco/bwd_bf16_a32_rtz.cpp hsaco/bwd_bf16_causal_a16_rtna.cpp hsaco/bwd_bf16_causal_a16_rtne.cpp hsaco/bwd_bf16_causal_a16_rtz.cpp hsaco/bwd_bf16_causal_a32_rtna.cpp hsaco/bwd_bf16_causal_a32_rtne.cpp hsaco/bwd_bf16_causal_a32_rtz.cpp hsaco/bwd_bf16_spec_a32.cpp hsaco/bwd_bf16_spec_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a16.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_spec_a32.cpp hsaco/bwd_fp16_spec_causal_a32.cpp fmha_bwd.cpp
)
add_executable
(
${
EXAMPLE_FMHA_BWD
}
EXCLUDE_FROM_ALL
hsaco/bwd_hd64_bf16_a16_rtna.cpp
hsaco/bwd_bf16_a16_rtna.cpp hsaco/bwd_bf16_a16_rtne.cpp hsaco/bwd_bf16_a16_rtz.cpp hsaco/bwd_bf16_a32_rtna.cpp hsaco/bwd_bf16_a32_rtne.cpp hsaco/bwd_bf16_a32_rtz.cpp hsaco/bwd_bf16_causal_a16_rtna.cpp hsaco/bwd_bf16_causal_a16_rtne.cpp hsaco/bwd_bf16_causal_a16_rtz.cpp hsaco/bwd_bf16_causal_a32_rtna.cpp hsaco/bwd_bf16_causal_a32_rtne.cpp hsaco/bwd_bf16_causal_a32_rtz.cpp hsaco/bwd_bf16_spec_a32.cpp hsaco/bwd_bf16_spec_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a16.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_spec_a32.cpp hsaco/bwd_fp16_spec_causal_a32.cpp fmha_bwd.cpp
)
target_include_directories
(
${
EXAMPLE_FMHA_BWD
}
PRIVATE
${
CMAKE_CURRENT_LIST_DIR
}
)
target_sources
(
${
EXAMPLE_FMHA_BWD
}
PRIVATE
${
FMHA_BWD_GEN_BLOBS
}
)
...
...
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
66cbdd6c
...
...
@@ -554,184 +554,200 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
if (t.uses_bwd_v3 == true){{
if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) &&
(a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 128 == 0) && (a.hdim_q == 128) && (a.hdim_v == 128) && (t.is_deterministic == false) &&
(a.stride_q == a.stride_o /*i_perm == o_perm*/)) {{
if(t.data_type.compare("fp16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_spec_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_spec_a32, bwd_v3_name, io_perm, 32, 128);
return r;
(a.seqlen_q == a.seqlen_k) && (t.is_deterministic == false) && (a.stride_q == a.stride_o /*i_perm == o_perm*/)) {{
if((a.hdim_q == 128) && (a.hdim_v == 128) && (a.seqlen_k % 128 == 0)){{
if(t.data_type.compare("fp16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_spec_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_spec_a32, bwd_v3_name, io_perm, 32, 128);
return r;
}}
else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_a32, bwd_v3_name, io_perm, 16, 192);
return r;
}}
}}
else if((t.is_v3_
spec
== false) && (a.nhead_q % a.nhead_k == 0)){{
else if((t.is_v3_
atomic_fp32
== false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_a32";
const std::string bwd_v3_name = "bwd_v3_fp16_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_
, convert_dq_trait_
>(s, a, bwd_fp16_a
32
, bwd_v3_name, io_perm, 16, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_fp16_a
16
, bwd_v3_name, io_perm, 16, 192);
return r;
}}
}}
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_fp16_a16, bwd_v3_name, io_perm, 16, 192);
return r;
}}
}}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_spec_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_spec_causal_a32, bwd_v3_name, io_perm, 32, 128);
return r;
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_spec_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_spec_causal_a32, bwd_v3_name, io_perm, 32, 128);
return r;
}}
else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_causal_a32, bwd_v3_name, io_perm, 16, 192);
return r;
}}
}}
else if((t.is_v3_
spec
== false) && (a.nhead_q % a.nhead_k == 0)){{
else if((t.is_v3_
atomic_fp32
== false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32";
const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_
, convert_dq_trait_
>(s, a, bwd_fp16_causal_a
32
, bwd_v3_name, io_perm, 16, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_fp16_causal_a
16
, bwd_v3_name, io_perm, 16, 192);
return r;
}}
}}
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_fp16_causal_a16, bwd_v3_name, io_perm, 16, 192);
return r;
}}
}}
}}
else if(t.data_type.compare("bf16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_spec_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_spec_a32, bwd_v3_name, io_perm, 32, 128);
return r;
else if(t.data_type.compare("bf16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_spec_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_spec_a32, bwd_v3_name, io_perm, 32, 128);
return r;
}}
else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32_rtne, bwd_v3_name, io_perm, 16, 192);
return r;
}}
else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32_rtna, bwd_v3_name, io_perm, 16, 192);
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32_rtz, bwd_v3_name, io_perm, 16, 192);
return r;
}}
}}
}}
else if((t.is_v3_
spec
== false) && (a.nhead_q % a.nhead_k == 0)){{
else if((t.is_v3_
atomic_fp32
== false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne";
const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_
, convert_dq_trait_
>(s, a, bwd_bf16_a
32
_rtne, bwd_v3_name, io_perm, 16, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a
16
_rtne, bwd_v3_name, io_perm, 16, 192);
return r;
}}
else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna";
const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_
, convert_dq_trait_
>(s, a, bwd_bf16_a
32
_rtna, bwd_v3_name, io_perm, 16, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a
16
_rtna, bwd_v3_name, io_perm, 16, 192);
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz";
const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_
, convert_dq_trait_
>(s, a, bwd_bf16_a
32
_rtz, bwd_v3_name, io_perm, 16, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a
16
_rtz, bwd_v3_name, io_perm, 16, 192);
return r;
}}
}}
}}
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16_rtne, bwd_v3_name, io_perm, 16, 192);
return r;
}}
else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16_rtna, bwd_v3_name, io_perm, 16, 192);
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16_rtz, bwd_v3_name, io_perm, 16, 192);
return r;
}}
}}
}}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_spec_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_spec_causal_a32, bwd_v3_name, io_perm, 32, 128);
return r;
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_spec_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_spec_causal_a32, bwd_v3_name, io_perm, 32, 128);
return r;
}}
else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32_rtne, bwd_v3_name, io_perm, 16, 192);
return r;
}}
else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32_rtna, bwd_v3_name, io_perm, 16, 192);
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32_rtz, bwd_v3_name, io_perm, 16, 192);
return r;
}}
}}
}}
else if((t.is_v3_
spec
== false) && (a.nhead_q % a.nhead_k == 0)){{
else if((t.is_v3_
atomic_fp32
== false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne";
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_
, convert_dq_trait_
>(s, a, bwd_bf16_causal_a
32
_rtne, bwd_v3_name, io_perm, 16, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a
16
_rtne, bwd_v3_name, io_perm, 16, 192);
return r;
}}
else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna";
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_
, convert_dq_trait_
>(s, a, bwd_bf16_causal_a
32
_rtna, bwd_v3_name, io_perm, 16, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a
16
_rtna, bwd_v3_name, io_perm, 16, 192);
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz";
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_
, convert_dq_trait_
>(s, a, bwd_bf16_causal_a
32
_rtz, bwd_v3_name, io_perm, 16, 192);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a
16
_rtz, bwd_v3_name, io_perm, 16, 192);
return r;
}}
}}
}}
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16_rtne, bwd_v3_name, io_perm, 16, 192);
return r;
}}
else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16_rtna, bwd_v3_name, io_perm, 16, 192);
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16_rtz, bwd_v3_name, io_perm, 16, 192);
return r;
}}
}}
else if((a.hdim_q == 64) && (a.hdim_v == 64) && (a.seqlen_k % 64 == 0)){{
if(t.data_type.compare("bf16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_hd64_bf16_a16_rtna, bwd_v3_name, io_perm, 32, 192);
return r;
}}
}}
}}
}}
...
...
example/ck_tile/01_fmha/hsaco/bwd_hd64_bf16_a16_rtna.cpp
0 → 100644
View file @
66cbdd6c
This source diff could not be displayed because it is too large. You can
view the blob
instead.
example/ck_tile/01_fmha/hsaco/fmha_hsaco.hpp
View file @
66cbdd6c
...
...
@@ -23,3 +23,4 @@ extern unsigned char bwd_fp16_causal_a16[];
extern
unsigned
char
bwd_fp16_causal_a32
[];
extern
unsigned
char
bwd_fp16_spec_a32
[];
extern
unsigned
char
bwd_fp16_spec_causal_a32
[];
extern
unsigned
char
bwd_hd64_bf16_a16_rtna
[];
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment