Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
871c7556
Commit
871c7556
authored
Sep 29, 2024
by
danyao12
Browse files
add bf16+a16 rtz
parent
2dafca1f
Changes
8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
4153 additions
and
15 deletions
+4153
-15
example/ck_tile/01_fmha/CMakeLists.txt
example/ck_tile/01_fmha/CMakeLists.txt
+1
-1
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+28
-10
example/ck_tile/01_fmha/fmha_bwd.cpp
example/ck_tile/01_fmha/fmha_bwd.cpp
+7
-2
example/ck_tile/01_fmha/fmha_bwd.hpp
example/ck_tile/01_fmha/fmha_bwd.hpp
+1
-0
example/ck_tile/01_fmha/hsaco/bwd_bf16_a16_rtz.cpp
example/ck_tile/01_fmha/hsaco/bwd_bf16_a16_rtz.cpp
+1998
-0
example/ck_tile/01_fmha/hsaco/bwd_bf16_causal_a16_rtz.cpp
example/ck_tile/01_fmha/hsaco/bwd_bf16_causal_a16_rtz.cpp
+2114
-0
example/ck_tile/01_fmha/hsaco/fmha_hsaco.hpp
example/ck_tile/01_fmha/hsaco/fmha_hsaco.hpp
+2
-0
example/ck_tile/01_fmha/script/smoke_test_bwd_xqa_ext.sh
example/ck_tile/01_fmha/script/smoke_test_bwd_xqa_ext.sh
+2
-2
No files found.
example/ck_tile/01_fmha/CMakeLists.txt
View file @
871c7556
...
@@ -67,7 +67,7 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
...
@@ -67,7 +67,7 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
# to be included in "make all/install/check"
# to be included in "make all/install/check"
message
(
"adding example
${
EXAMPLE_FMHA_BWD
}
"
)
message
(
"adding example
${
EXAMPLE_FMHA_BWD
}
"
)
add_executable
(
${
EXAMPLE_FMHA_BWD
}
EXCLUDE_FROM_ALL hsaco/bwd_bf16_a16.cpp hsaco/bwd_bf16_a32.cpp hsaco/bwd_bf16_causal_a16.cpp hsaco/bwd_bf16_causal_a32.cpp hsaco/bwd_bf16_nocoex_a32.cpp hsaco/bwd_bf16_nocoex_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a16.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_nocoex_a32.cpp hsaco/bwd_fp16_nocoex_causal_a32.cpp fmha_bwd.cpp
)
add_executable
(
${
EXAMPLE_FMHA_BWD
}
EXCLUDE_FROM_ALL hsaco/bwd_bf16_a16.cpp
hsaco/bwd_bf16_a16_rtz.cpp
hsaco/bwd_bf16_a32.cpp hsaco/bwd_bf16_causal_a16.cpp
hsaco/bwd_bf16_causal_a16_rtz.cpp
hsaco/bwd_bf16_causal_a32.cpp hsaco/bwd_bf16_nocoex_a32.cpp hsaco/bwd_bf16_nocoex_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a16.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_nocoex_a32.cpp hsaco/bwd_fp16_nocoex_causal_a32.cpp fmha_bwd.cpp
)
target_include_directories
(
${
EXAMPLE_FMHA_BWD
}
PRIVATE
${
CMAKE_CURRENT_LIST_DIR
}
)
target_include_directories
(
${
EXAMPLE_FMHA_BWD
}
PRIVATE
${
CMAKE_CURRENT_LIST_DIR
}
)
target_sources
(
${
EXAMPLE_FMHA_BWD
}
PRIVATE
${
FMHA_BWD_GEN_BLOBS
}
)
target_sources
(
${
EXAMPLE_FMHA_BWD
}
PRIVATE
${
FMHA_BWD_GEN_BLOBS
}
)
...
...
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
871c7556
...
@@ -620,11 +620,20 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -620,11 +620,20 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
}}
}}
}}
}}
else if((t.is_asm_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
else if((t.is_asm_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
if(t.is_asm_rtz_cvt == true){{
const std::string bwd_ext_name = "bwd_ext_bf16_a16";
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
bool io_perm = a.nhead_stride_q > a.stride_q;
const std::string bwd_ext_name = "bwd_ext_bf16_a16_rtz";
r = fmha_ext_bwd_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16, bwd_ext_name, io_perm);
bool io_perm = a.nhead_stride_q > a.stride_q;
return r;
r = fmha_ext_bwd_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16_rtz, bwd_ext_name, io_perm);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_bf16_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16, bwd_ext_name, io_perm);
return r;
}}
}}
}}
}}
}}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
...
@@ -648,11 +657,20 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
...
@@ -648,11 +657,20 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
}}
}}
}}
}}
else if((t.is_asm_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
else if((t.is_asm_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
if(t.is_asm_rtz_cvt == true){{
const std::string bwd_ext_name = "bwd_ext_bf16_causal_a16";
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
bool io_perm = a.nhead_stride_q > a.stride_q;
const std::string bwd_ext_name = "bwd_ext_bf16_causal_a16_rtz";
r = fmha_ext_bwd_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16, bwd_ext_name, io_perm);
bool io_perm = a.nhead_stride_q > a.stride_q;
return r;
r = fmha_ext_bwd_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16_rtz, bwd_ext_name, io_perm);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_bf16_causal_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16, bwd_ext_name, io_perm);
return r;
}}
}}
}}
}}
}}
}}
}}
...
...
example/ck_tile/01_fmha/fmha_bwd.cpp
View file @
871c7556
...
@@ -99,7 +99,10 @@ auto create_args(int argc, char* argv[])
...
@@ -99,7 +99,10 @@ auto create_args(int argc, char* argv[])
"if set to 0 will use atomic fp16/bf16(w/o convert_dq kernel) when ext_asm is set to 1"
)
"if set to 0 will use atomic fp16/bf16(w/o convert_dq kernel) when ext_asm is set to 1"
)
.
insert
(
"asm_no_coex"
,
.
insert
(
"asm_no_coex"
,
"0"
,
"0"
,
"if set to 1 will use non-coexectuion kernel when ext_asm is set to 1"
);
"if set to 1 will use non-coexectuion kernel when ext_asm is set to 1"
)
.
insert
(
"asm_rtz_cvt"
,
"0"
,
"if set to 1 will use float to bf16 RTZ convert when ext_asm is set to 1"
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
return
std
::
make_tuple
(
result
,
arg_parser
);
return
std
::
make_tuple
(
result
,
arg_parser
);
...
@@ -191,6 +194,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -191,6 +194,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
bool
ext_asm
=
arg_parser
.
get_bool
(
"ext_asm"
);
bool
ext_asm
=
arg_parser
.
get_bool
(
"ext_asm"
);
bool
asm_atomic_fp32
=
arg_parser
.
get_bool
(
"asm_atomic_fp32"
);
bool
asm_atomic_fp32
=
arg_parser
.
get_bool
(
"asm_atomic_fp32"
);
bool
asm_no_coex
=
arg_parser
.
get_bool
(
"asm_no_coex"
);
bool
asm_no_coex
=
arg_parser
.
get_bool
(
"asm_no_coex"
);
bool
asm_rtz_cvt
=
arg_parser
.
get_bool
(
"asm_rtz_cvt"
);
ck_tile
::
stream_config
stream_config
{
nullptr
,
ck_tile
::
stream_config
stream_config
{
nullptr
,
true
,
true
,
...
@@ -428,7 +432,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -428,7 +432,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
deterministic
,
deterministic
,
ext_asm
,
ext_asm
,
asm_atomic_fp32
,
asm_atomic_fp32
,
asm_no_coex
};
asm_no_coex
,
asm_rtz_cvt
};
auto
fmha_args
=
[
&
]()
{
auto
fmha_args
=
[
&
]()
{
assert
(
nhead
%
nhead_k
==
0
);
assert
(
nhead
%
nhead_k
==
0
);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
...
...
example/ck_tile/01_fmha/fmha_bwd.hpp
View file @
871c7556
...
@@ -441,6 +441,7 @@ struct fmha_bwd_traits
...
@@ -441,6 +441,7 @@ struct fmha_bwd_traits
bool
uses_ext_asm
;
bool
uses_ext_asm
;
bool
is_asm_atomic_fp32
;
bool
is_asm_atomic_fp32
;
bool
is_asm_no_coex
;
bool
is_asm_no_coex
;
bool
is_asm_rtz_cvt
;
// TODO: padding check is inside this api
// TODO: padding check is inside this api
};
};
float
fmha_bwd
(
fmha_bwd_traits
,
fmha_bwd_args
,
const
ck_tile
::
stream_config
&
);
float
fmha_bwd
(
fmha_bwd_traits
,
fmha_bwd_args
,
const
ck_tile
::
stream_config
&
);
example/ck_tile/01_fmha/hsaco/bwd_bf16_a16_rtz.cpp
0 → 100644
View file @
871c7556
This diff is collapsed.
Click to expand it.
example/ck_tile/01_fmha/hsaco/bwd_bf16_causal_a16_rtz.cpp
0 → 100644
View file @
871c7556
This source diff could not be displayed because it is too large. You can
view the blob
instead.
example/ck_tile/01_fmha/hsaco/fmha_hsaco.hpp
View file @
871c7556
...
@@ -4,8 +4,10 @@
...
@@ -4,8 +4,10 @@
#pragma once
#pragma once
extern
unsigned
char
bwd_bf16_a16
[];
extern
unsigned
char
bwd_bf16_a16
[];
extern
unsigned
char
bwd_bf16_a16_rtz
[];
extern
unsigned
char
bwd_bf16_a32
[];
extern
unsigned
char
bwd_bf16_a32
[];
extern
unsigned
char
bwd_bf16_causal_a16
[];
extern
unsigned
char
bwd_bf16_causal_a16
[];
extern
unsigned
char
bwd_bf16_causal_a16_rtz
[];
extern
unsigned
char
bwd_bf16_causal_a32
[];
extern
unsigned
char
bwd_bf16_causal_a32
[];
extern
unsigned
char
bwd_bf16_nocoex_a32
[];
extern
unsigned
char
bwd_bf16_nocoex_a32
[];
extern
unsigned
char
bwd_bf16_nocoex_causal_a32
[];
extern
unsigned
char
bwd_bf16_nocoex_causal_a32
[];
...
...
example/ck_tile/01_fmha/script/smoke_test_bwd_xqa_ext.sh
View file @
871c7556
...
@@ -13,8 +13,8 @@ for perm in 0 1 ; do
...
@@ -13,8 +13,8 @@ for perm in 0 1 ; do
for
hdim
in
128
;
do
for
hdim
in
128
;
do
for
mask
in
0 1
;
do
for
mask
in
0 1
;
do
$EXE
-prec
=
$prec
-b
=
2
-h
=
4
-h_k
=
2
-d
=
$hdim
-s
=
512
-iperm
=
$perm
-operm
=
$perm
-mask
=
$mask
-ext_asm
=
1
-asm_atomic_fp32
=
0
-mode
=
0
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-b
=
2
-h
=
4
-h_k
=
2
-d
=
$hdim
-s
=
512
-iperm
=
$perm
-operm
=
$perm
-mask
=
$mask
-ext_asm
=
1
-asm_atomic_fp32
=
0
-asm_rtz_cvt
=
1
-mode
=
0
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-b
=
1
-h
=
3
-h_k
=
1
-d
=
$hdim
-s
=
768
-iperm
=
$perm
-operm
=
$perm
-mask
=
$mask
-ext_asm
=
1
-asm_atomic_fp32
=
0
-mode
=
0
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-b
=
1
-h
=
3
-h_k
=
1
-d
=
$hdim
-s
=
768
-iperm
=
$perm
-operm
=
$perm
-mask
=
$mask
-ext_asm
=
1
-asm_atomic_fp32
=
0
-asm_rtz_cvt
=
1
-mode
=
0
-kname
=
$KNAME
$COMMON_ARGS
done
done
done
done
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment