Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
5ab137f4
Commit
5ab137f4
authored
Sep 19, 2024
by
danyao12
Browse files
add traits
parent
a0491b67
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
47 additions
and
37 deletions
+47
-37
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+33
-31
example/ck_tile/01_fmha/fmha_bwd.cpp
example/ck_tile/01_fmha/fmha_bwd.cpp
+12
-6
example/ck_tile/01_fmha/fmha_bwd.hpp
example/ck_tile/01_fmha/fmha_bwd.hpp
+2
-0
No files found.
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
5ab137f4
...
@@ -333,38 +333,40 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned
...
@@ -333,38 +333,40 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned
float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
float r = -1;
float r = -1;
if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) &&
if (t.uses_ext_asm == true){{
(a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 128 == 0) && (a.hdim_q == 128) && (a.hdim_v == 128) && (t.is_deterministic == false)) {{
if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) &&
if(t.data_type.compare("fp16") == 0){{
(a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 128 == 0) && (a.hdim_q == 128) && (a.hdim_v == 128) && (t.is_deterministic == false)) {{
if(t.mask_type == mask_enum::no_mask){{
if(t.data_type.compare("fp16") == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
if(t.mask_type == mask_enum::no_mask){{
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_fp16_a32";
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_a32, bwd_ext_name);
const std::string bwd_ext_name = "bwd_ext_fp16_a32";
return r;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_a32, bwd_ext_name);
}}
return r;
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
}}
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_fp16_causal_a32";
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_causal_a32, bwd_ext_name);
const std::string bwd_ext_name = "bwd_ext_fp16_causal_a32";
return r;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_causal_a32, bwd_ext_name);
return r;
}}
}}
}}
}}
else if(t.data_type.compare("bf16") == 0){{
else
if(t.
data
_type
.compare("bf16") == 0
){{
if(t.
mask
_type
== mask_enum::no_mask
){{
if(t.mask_type == mask_enum::no_mask){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using
dot_do_o
_trait_ = fmha_bwd_
dot_do_o
_traits_<128, ck_tile::bf16_t, false, false, false>;
using
convert_dq
_trait_ = fmha_bwd_
convert_dq
_traits_<128, ck_tile::bf16_t, false, false,
false,
false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>
;
const std::string bwd_ext_name = "bwd_ext_bf16_a32"
;
const std::string bwd_ext_name = "bwd_ext_bf16_a32"
;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32, bwd_ext_name)
;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32, bwd_ext_name)
;
return r
;
return r;
}}
}}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using
dot_do_o
_trait_ = fmha_bwd_
dot_do_o
_traits_<128, ck_tile::bf16_t, false, false, false>;
using
convert_dq
_trait_ = fmha_bwd_
convert_dq
_traits_<128, ck_tile::bf16_t, false, false,
false,
false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>
;
const std::string bwd_ext_name = "bwd_ext_bf16_causal_a32"
;
const std::string bwd_ext_name = "bwd_ext
_bf16_causal_a32
"
;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd
_bf16_causal_a32
, bwd_ext_name)
;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32, bwd_ext_name)
;
return r
;
return r;
}}
}}
}}
}}
}}
}}
}}
...
...
example/ck_tile/01_fmha/fmha_bwd.cpp
View file @
5ab137f4
...
@@ -91,7 +91,9 @@ auto create_args(int argc, char* argv[])
...
@@ -91,7 +91,9 @@ auto create_args(int argc, char* argv[])
.
insert
(
"deterministic"
,
.
insert
(
"deterministic"
,
"0"
,
"0"
,
"if set to 1 will use multi-buffer reduction strategy for dq, atomic opeartion "
"if set to 1 will use multi-buffer reduction strategy for dq, atomic opeartion "
"will not be used"
);
"will not be used"
)
.
insert
(
"ext_asm"
,
"0"
,
"if set to 1, some cases will call the ext asm dqdkdv kernel"
)
.
insert
(
"asm_atomic_fp32"
,
"1"
,
"if set to 0, atomic fp16/bf16 is used when calling the ext asm dqdkdv kernel"
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
return
std
::
make_tuple
(
result
,
arg_parser
);
return
std
::
make_tuple
(
result
,
arg_parser
);
...
@@ -176,10 +178,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -176,10 +178,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
seed
.
reset
();
seed
.
reset
();
}
}
int
stream_warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
stream_warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
stream_repeat
=
arg_parser
.
get_int
(
"repeat"
);
int
stream_repeat
=
arg_parser
.
get_int
(
"repeat"
);
bool
kname
=
arg_parser
.
get_bool
(
"kname"
);
bool
kname
=
arg_parser
.
get_bool
(
"kname"
);
bool
deterministic
=
arg_parser
.
get_bool
(
"deterministic"
);
bool
deterministic
=
arg_parser
.
get_bool
(
"deterministic"
);
bool
ext_asm
=
arg_parser
.
get_bool
(
"ext_asm"
);
bool
asm_atomic_fp32
=
arg_parser
.
get_bool
(
"asm_atomic_fp32"
);
ck_tile
::
stream_config
stream_config
{
nullptr
,
ck_tile
::
stream_config
stream_config
{
nullptr
,
true
,
true
,
...
@@ -416,7 +420,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -416,7 +420,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
use_dbias
,
use_dbias
,
p_drop
>
0.0
f
,
p_drop
>
0.0
f
,
s_randval
,
s_randval
,
deterministic
};
deterministic
,
ext_asm
,
asm_atomic_fp32
};
auto
fmha_args
=
[
&
]()
{
auto
fmha_args
=
[
&
]()
{
assert
(
nhead
%
nhead_k
==
0
);
assert
(
nhead
%
nhead_k
==
0
);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
...
...
example/ck_tile/01_fmha/fmha_bwd.hpp
View file @
5ab137f4
...
@@ -438,6 +438,8 @@ struct fmha_bwd_traits
...
@@ -438,6 +438,8 @@ struct fmha_bwd_traits
bool
has_dropout
;
bool
has_dropout
;
bool
is_store_randval
;
bool
is_store_randval
;
bool
is_deterministic
;
bool
is_deterministic
;
bool
uses_ext_asm
;
bool
is_asm_atomic_fp32
;
// TODO: padding check is inside this api
// TODO: padding check is inside this api
};
};
float
fmha_bwd
(
fmha_bwd_traits
,
fmha_bwd_args
,
const
ck_tile
::
stream_config
&
);
float
fmha_bwd
(
fmha_bwd_traits
,
fmha_bwd_args
,
const
ck_tile
::
stream_config
&
);
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment