Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
44f4498a
Commit
44f4498a
authored
Jun 14, 2023
by
guangzlu
Browse files
optimized code for fwd
parent
e00c308d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
61 deletions
+18
-61
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle_pt2.hpp
..._batched_multihead_attention_forward_xdl_cshuffle_pt2.hpp
+18
-61
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle_pt2.hpp
View file @
44f4498a
...
...
@@ -1128,26 +1128,10 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
acc_thread_buf
,
num_k_block_main_loop
);
// do MNK padding or upper triangular masking
if
constexpr
(
MaskOutUpperTriangle
||
PadN
)
{
// 8d thread_desc in thread scope
constexpr
auto
c_thread_lengths
=
blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
();
// 8d block_desc in block scope
constexpr
auto
c_block_lengths
=
blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
();
constexpr
auto
M0
=
c_block_lengths
[
I0
];
constexpr
auto
N0
=
c_block_lengths
[
I1
];
constexpr
auto
M1
=
c_block_lengths
[
I2
];
constexpr
auto
N1
=
c_block_lengths
[
I3
];
constexpr
auto
M2
=
c_block_lengths
[
I4
];
constexpr
auto
N2
=
c_block_lengths
[
I5
];
constexpr
auto
N3
=
c_block_lengths
[
I6
];
constexpr
auto
N4
=
c_block_lengths
[
I7
];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using
Acc0TileIterator
=
SpaceFillingCurve
<
...
...
@@ -1157,11 +1141,15 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
false
>
;
// SnakeCurved
constexpr
auto
block_idx_to_m_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M0
,
M1
,
M2
)),
make_unmerge_transform
(
make_tuple
(
N0
,
N1
,
N2
,
N3
,
N4
))),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
z
M0
,
z
M1
,
z
M2
)),
make_unmerge_transform
(
make_tuple
(
z
N0
,
z
N1
,
z
N2
,
z
N3
,
z
N4
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
,
6
,
7
>
{}));
// do MNK padding or upper triangular masking
if
constexpr
(
MaskOutUpperTriangle
||
PadN
)
{
static_for
<
0
,
Acc0TileIterator
::
GetNumOfAccess
(),
1
>
{}([
&
](
auto
i
)
{
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
i
)
+
acc0_thread_origin
;
auto
m_local
=
...
...
@@ -1196,37 +1184,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
if
constexpr
(
IsDropout
)
// dropout
{
// 8d thread_desc in thread scope
constexpr
auto
c_thread_lengths
=
blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
();
// 8d block_desc in block scope
constexpr
auto
c_block_lengths
=
blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
();
constexpr
auto
M0
=
c_block_lengths
[
I0
];
constexpr
auto
N0
=
c_block_lengths
[
I1
];
constexpr
auto
M1
=
c_block_lengths
[
I2
];
constexpr
auto
N1
=
c_block_lengths
[
I3
];
constexpr
auto
M2
=
c_block_lengths
[
I4
];
constexpr
auto
N2
=
c_block_lengths
[
I5
];
constexpr
auto
N3
=
c_block_lengths
[
I6
];
constexpr
auto
N4
=
c_block_lengths
[
I7
];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using
Acc0TileIterator
=
SpaceFillingCurve
<
decltype
(
c_thread_lengths
),
typename
arithmetic_sequence_gen
<
0
,
c_thread_lengths
.
Size
(),
1
>::
type
,
typename
uniform_sequence_gen
<
c_thread_lengths
.
Size
(),
1
>::
type
,
false
>
;
// SnakeCurved
constexpr
auto
block_idx_to_m_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M0
,
M1
,
M2
)),
make_unmerge_transform
(
make_tuple
(
N0
,
N1
,
N2
,
N3
,
N4
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
,
6
,
7
>
{}));
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
I0
)
+
acc0_thread_origin
;
auto
m_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I0
];
auto
n_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment