Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
91e1a796
Unverified
Commit
91e1a796
authored
Jan 22, 2025
by
Max Podkorytov
Browse files
move elementwise functors
parent
052a7265
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
8 deletions
+25
-8
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
+25
-8
No files found.
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
View file @
91e1a796
...
...
@@ -1302,24 +1302,33 @@ struct FmhaFwdKernel
}
}();
constexpr
auto
q_arg_element_func
=
identity
{};
constexpr
auto
k_arg_element_func
=
identity
{};
constexpr
auto
v_arg_element_func
=
identity
{};
constexpr
auto
bias_arg_element_func
=
identity
{};
constexpr
auto
lse_arg_element_func
=
identity
{};
constexpr
auto
s_acc_arg_element_func
=
identity
{};
constexpr
auto
p_compute_arg_element_func
=
identity
{};
constexpr
auto
o_acc_arg_element_func
=
identity
{};
auto
o_acc_tile
=
[
&
]()
{
if
constexpr
(
kDoFp8StaticQuant
)
{
return
FmhaPipeline
{}(
q_dram_window
,
identity
{},
// q
_element_func
q_arg
_element_func
,
k_dram_window
,
identity
{},
// k
_element_func
k_arg
_element_func
,
v_dram_window
,
identity
{},
// v
_element_func
v_arg
_element_func
,
bias_dram_window
,
identity
{},
//
bias_element_func
bias
_arg
_element_func
,
randval_dram_window
,
lse_dram_window
,
identity
{}
,
// lse_element_func
identity
{}
,
// s_acc_element_func
scales
{
kargs
.
scale_p
},
// p_compute_element_func
composes
(
saturates
<
fp8_t
>
{},
scales
{
kargs
.
scale_o
}),
// o_acc_element_func
lse_arg_element_func
,
s_acc_arg_element_func
,
composes
(
p_compute_arg_element_func
,
scales
{
kargs
.
scale_p
}
)
,
composes
(
o_acc_arg_element_func
,
saturates
<
fp8_t
>
{},
scales
{
kargs
.
scale_o
}),
mask
,
position_encoding
,
kargs
.
scale_s
,
...
...
@@ -1329,11 +1338,19 @@ struct FmhaFwdKernel
else
{
return
FmhaPipeline
{}(
q_dram_window
,
q_arg_element_func
,
k_dram_window
,
k_arg_element_func
,
v_dram_window
,
v_arg_element_func
,
bias_dram_window
,
bias_arg_element_func
,
randval_dram_window
,
lse_dram_window
,
lse_arg_element_func
,
s_acc_arg_element_func
,
p_compute_arg_element_func
,
o_acc_arg_element_func
,
mask
,
position_encoding
,
kargs
.
scale_s
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment