Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
3e75a4dd
Commit
3e75a4dd
authored
Feb 04, 2025
by
Aviral Goel
Browse files
Merge remote-tracking branch 'upstream/ck-flex' into aviralgoel-amd-jenkins
to fix accuracy errors
parents
9248f595
2c8e04aa
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
5 additions
and
3 deletions
+5
-3
example/ck_tile/18_flexattn/CMakeLists.txt
example/ck_tile/18_flexattn/CMakeLists.txt
+3
-1
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs.hpp
...e/ops/fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs.hpp
+1
-1
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs_async.hpp
...fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs_async.hpp
+1
-1
No files found.
example/ck_tile/18_flexattn/CMakeLists.txt
View file @
3e75a4dd
...
@@ -8,9 +8,11 @@ endif()
...
@@ -8,9 +8,11 @@ endif()
variable_watch
(
FMHA_SCORE_MOD_F
)
variable_watch
(
FMHA_SCORE_MOD_F
)
set
(
FMHA_SCORE_MOD_F
[[s + static_cast<decltype(s)>((q_idx - v_idx) % 8)]]
)
set
(
FMHA_SCORE_MOD_F
[[s + static_cast<decltype(s)>((q_idx - v_idx) % 8)]]
)
# set(FMHA_SCORE_MOD_F [[s]])
variable_watch
(
FMHA_PRE_SOFTMAX_F
)
variable_watch
(
FMHA_PRE_SOFTMAX_F
)
set
(
FMHA_PRE_SOFTMAX_F
[[static_cast<decltype(s)>(tanh(s*1.0)/1.0)]]
)
# set(FMHA_PRE_SOFTMAX_F [[static_cast<decltype(s)>(tanh(s*1.0)/1.0)]])
set
(
FMHA_PRE_SOFTMAX_F
[[s]]
)
foreach
(
api
${
FMHA_FWD_ENABLE_APIS
}
)
foreach
(
api
${
FMHA_FWD_ENABLE_APIS
}
)
if
(
NOT
"
${
api
}
"
IN_LIST FMHA_FWD_KNOWN_APIS
)
if
(
NOT
"
${
api
}
"
IN_LIST FMHA_FWD_KNOWN_APIS
)
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs.hpp
View file @
3e75a4dd
...
@@ -350,7 +350,7 @@ struct BlockFmhaPipelineQRKSVS
...
@@ -350,7 +350,7 @@ struct BlockFmhaPipelineQRKSVS
s_acc
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
s_acc
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
1
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
s_acc
(
i_j_idx
)
=
score_mod
(
s_acc
(
i_j_idx
),
row
,
col
);
s_acc
(
i_j_idx
)
=
score_mod
(
s_acc
(
i_j_idx
),
row
,
col
);
});
});
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs_async.hpp
View file @
3e75a4dd
...
@@ -420,7 +420,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -420,7 +420,7 @@ struct BlockFmhaPipelineQRKSVSAsync
s_acc
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
s_acc
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
row
=
q_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
1
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
s_acc
(
i_j_idx
)
=
score_mod
(
s_acc
(
i_j_idx
),
row
,
col
);
s_acc
(
i_j_idx
)
=
score_mod
(
s_acc
(
i_j_idx
),
row
,
col
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment