Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
66cbdd6c
Commit
66cbdd6c
authored
Dec 17, 2024
by
danyao12
Browse files
fav3 bwd hd64 bf16 a16 verification passed
parent
55d982c3
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
2372 additions
and
130 deletions
+2372
-130
example/ck_tile/01_fmha/CMakeLists.txt
example/ck_tile/01_fmha/CMakeLists.txt
+1
-1
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+145
-129
example/ck_tile/01_fmha/hsaco/bwd_hd64_bf16_a16_rtna.cpp
example/ck_tile/01_fmha/hsaco/bwd_hd64_bf16_a16_rtna.cpp
+2225
-0
example/ck_tile/01_fmha/hsaco/fmha_hsaco.hpp
example/ck_tile/01_fmha/hsaco/fmha_hsaco.hpp
+1
-0
No files found.
example/ck_tile/01_fmha/CMakeLists.txt
View file @
66cbdd6c
...
@@ -67,7 +67,7 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
...
@@ -67,7 +67,7 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
# to be included in "make all/install/check"
# to be included in "make all/install/check"
message
(
"adding example
${
EXAMPLE_FMHA_BWD
}
"
)
message
(
"adding example
${
EXAMPLE_FMHA_BWD
}
"
)
add_executable
(
${
EXAMPLE_FMHA_BWD
}
EXCLUDE_FROM_ALL hsaco/bwd_bf16_a16_rtna.cpp hsaco/bwd_bf16_a16_rtne.cpp hsaco/bwd_bf16_a16_rtz.cpp hsaco/bwd_bf16_a32_rtna.cpp hsaco/bwd_bf16_a32_rtne.cpp hsaco/bwd_bf16_a32_rtz.cpp hsaco/bwd_bf16_causal_a16_rtna.cpp hsaco/bwd_bf16_causal_a16_rtne.cpp hsaco/bwd_bf16_causal_a16_rtz.cpp hsaco/bwd_bf16_causal_a32_rtna.cpp hsaco/bwd_bf16_causal_a32_rtne.cpp hsaco/bwd_bf16_causal_a32_rtz.cpp hsaco/bwd_bf16_spec_a32.cpp hsaco/bwd_bf16_spec_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a16.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_spec_a32.cpp hsaco/bwd_fp16_spec_causal_a32.cpp fmha_bwd.cpp
)
add_executable
(
${
EXAMPLE_FMHA_BWD
}
EXCLUDE_FROM_ALL
hsaco/bwd_hd64_bf16_a16_rtna.cpp
hsaco/bwd_bf16_a16_rtna.cpp hsaco/bwd_bf16_a16_rtne.cpp hsaco/bwd_bf16_a16_rtz.cpp hsaco/bwd_bf16_a32_rtna.cpp hsaco/bwd_bf16_a32_rtne.cpp hsaco/bwd_bf16_a32_rtz.cpp hsaco/bwd_bf16_causal_a16_rtna.cpp hsaco/bwd_bf16_causal_a16_rtne.cpp hsaco/bwd_bf16_causal_a16_rtz.cpp hsaco/bwd_bf16_causal_a32_rtna.cpp hsaco/bwd_bf16_causal_a32_rtne.cpp hsaco/bwd_bf16_causal_a32_rtz.cpp hsaco/bwd_bf16_spec_a32.cpp hsaco/bwd_bf16_spec_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a16.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_spec_a32.cpp hsaco/bwd_fp16_spec_causal_a32.cpp fmha_bwd.cpp
)
target_include_directories
(
${
EXAMPLE_FMHA_BWD
}
PRIVATE
${
CMAKE_CURRENT_LIST_DIR
}
)
target_include_directories
(
${
EXAMPLE_FMHA_BWD
}
PRIVATE
${
CMAKE_CURRENT_LIST_DIR
}
)
target_sources
(
${
EXAMPLE_FMHA_BWD
}
PRIVATE
${
FMHA_BWD_GEN_BLOBS
}
)
target_sources
(
${
EXAMPLE_FMHA_BWD
}
PRIVATE
${
FMHA_BWD_GEN_BLOBS
}
)
...
...
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
66cbdd6c
...
@@ -554,8 +554,8 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -554,8 +554,8 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
if (t.uses_bwd_v3 == true){{
if (t.uses_bwd_v3 == true){{
if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) &&
if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) &&
(a.seqlen_q == a.seqlen_k) && (
a.seqlen_k % 128 == 0) && (a.hdim_q == 128) && (a.hdim_v == 128) && (t.is_deterministic == false) &&
(a.seqlen_q == a.seqlen_k) && (
t.is_deterministic == false) && (a.stride_q == a.stride_o /*i_perm == o_perm*/)) {{
(a.stride_q == a.stride_o /*i_perm == o_perm*/
))
{{
if((a.hdim_q == 128) && (a.hdim_v == 128) && (a.seqlen_k % 128 == 0
)){{
if(t.data_type.compare("fp16") == 0){{
if(t.data_type.compare("fp16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
if(t.mask_type == mask_enum::no_mask){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
...
@@ -737,6 +737,22 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -737,6 +737,22 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
}}
}}
}}
}}
}}
}}
else if((a.hdim_q == 64) && (a.hdim_v == 64) && (a.seqlen_k % 64 == 0)){{
if(t.data_type.compare("bf16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_hd64_bf16_a16_rtna, bwd_v3_name, io_perm, 32, 192);
return r;
}}
}}
}}
}}
}}
}}
}}
}}
{F_dispatch}
{F_dispatch}
...
...
example/ck_tile/01_fmha/hsaco/bwd_hd64_bf16_a16_rtna.cpp
0 → 100644
View file @
66cbdd6c
This source diff could not be displayed because it is too large. You can
view the blob
instead.
example/ck_tile/01_fmha/hsaco/fmha_hsaco.hpp
View file @
66cbdd6c
...
@@ -23,3 +23,4 @@ extern unsigned char bwd_fp16_causal_a16[];
...
@@ -23,3 +23,4 @@ extern unsigned char bwd_fp16_causal_a16[];
extern
unsigned
char
bwd_fp16_causal_a32
[];
extern
unsigned
char
bwd_fp16_causal_a32
[];
extern
unsigned
char
bwd_fp16_spec_a32
[];
extern
unsigned
char
bwd_fp16_spec_a32
[];
extern
unsigned
char
bwd_fp16_spec_causal_a32
[];
extern
unsigned
char
bwd_fp16_spec_causal_a32
[];
extern
unsigned
char
bwd_hd64_bf16_a16_rtna
[];
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment