Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
2a198f14
Unverified
Commit
2a198f14
authored
Jan 28, 2025
by
Max Podkorytov
Browse files
add a hardcoded score_mod
parent
cd69c852
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
66 additions
and
2 deletions
+66
-2
example/ck_tile/18_flexattn/bias.hpp
example/ck_tile/18_flexattn/bias.hpp
+1
-1
example/ck_tile/18_flexattn/mask.hpp
example/ck_tile/18_flexattn/mask.hpp
+1
-1
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs.hpp
...e/ops/fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs.hpp
+32
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs_async.hpp
...fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs_async.hpp
+32
-0
No files found.
example/ck_tile/18_flexattn/bias.hpp
View file @
2a198f14
...
...
@@ -6,7 +6,7 @@
#include <ostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/flex_fmha.hpp"
// keep sync with BlockAttentionBiasEnum
enum
class
bias_enum
...
...
example/ck_tile/18_flexattn/mask.hpp
View file @
2a198f14
...
...
@@ -7,7 +7,7 @@
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/flex_fmha.hpp"
// keep this in sync with ck_tile::GenericAttentionMaskEnum
enum
class
mask_enum
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs.hpp
View file @
2a198f14
...
...
@@ -337,6 +337,38 @@ struct BlockFmhaPipelineQRKSVS
}
// STAGE 2, scale_s, add bias, mask, softmax
// FlexAttention score modifier operates directly on scores
{
    // Hardcoded score modifier with the FlexAttention-style signature
    // score_mod(score, batch, head, q_idx, kv_idx).
    // NOTE(review): it discards the incoming score `s` and returns only
    // (q_idx - v_idx) cast to the score type — confirm this is the intended
    // hardcoded/debug behavior rather than e.g. s + (q_idx - v_idx).
    auto score_mod = [](auto s,
                        ck_tile::index_t b,
                        ck_tile::index_t h,
                        ck_tile::index_t q_idx,
                        ck_tile::index_t v_idx)
    {
        (void)s; // suppress unused-parameter warnings; `s` still appears
        (void)b; // below only inside decltype(s)
        (void)h;
        return static_cast<decltype(s)>(q_idx - v_idx);
    };
    // Origin of the current K tile window; added to per-element tile
    // indices to form global key/column indices.
    const auto k_origin = k_dram_block_window.get_window_origin();
    constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
    // Visit every element of the distributed score tile s_acc and rewrite
    // it through score_mod.
    sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
        sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
            // Map distributed (per-lane) indices to tile-local x-indices.
            const auto tile_idx = get_x_indices_from_distributed_indices(
                s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
            // Global query row / key column of this score element.
            // `q_origin` is presumably the Q window origin computed earlier
            // in this pipeline — not visible in this hunk.
            const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
            const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
            constexpr auto i_j_idx = make_tuple(idx0, idx1);
            // Batch and head indices are hardcoded to 0 for now.
            const auto b = 0;
            const auto h = 0;
            s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), b, h, row, col);
        });
    });
}
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
s_acc
=
tile_elementwise_in
(
s_acc_element_func
,
s_acc
);
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs_async.hpp
View file @
2a198f14
...
...
@@ -407,6 +407,38 @@ struct BlockFmhaPipelineQRKSVSAsync
__builtin_amdgcn_sched_barrier
(
1
);
// STAGE 2, scale_s, add bias, mask, softmax
// FlexAttention score modifier operates directly on scores
{
    // Hardcoded score modifier with the FlexAttention-style signature
    // score_mod(score, batch, head, q_idx, kv_idx).
    // NOTE(review): it discards the incoming score `s` and returns only
    // (q_idx - v_idx) cast to the score type — confirm this is the intended
    // hardcoded/debug behavior rather than e.g. s + (q_idx - v_idx).
    auto score_mod = [](auto s,
                        ck_tile::index_t b,
                        ck_tile::index_t h,
                        ck_tile::index_t q_idx,
                        ck_tile::index_t v_idx)
    {
        (void)s; // suppress unused-parameter warnings; `s` still appears
        (void)b; // below only inside decltype(s)
        (void)h;
        return static_cast<decltype(s)>(q_idx - v_idx);
    };
    // Origin of the current K tile window; added to per-element tile
    // indices to form global key/column indices.
    const auto k_origin = k_dram_block_window.get_window_origin();
    constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
    // Visit every element of the distributed score tile s_acc and rewrite
    // it through score_mod.
    sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
        sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
            // Map distributed (per-lane) indices to tile-local x-indices.
            const auto tile_idx = get_x_indices_from_distributed_indices(
                s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
            // Global query row / key column of this score element.
            // `q_origin` is presumably the Q window origin computed earlier
            // in this pipeline — not visible in this hunk.
            const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
            const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
            constexpr auto i_j_idx = make_tuple(idx0, idx1);
            // Batch and head indices are hardcoded to 0 for now.
            const auto b = 0;
            const auto h = 0;
            s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), b, h, row, col);
        });
    });
}
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
s_acc
=
tile_elementwise_in
(
s_acc_element_func
,
s_acc
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment