Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c07c2b55
Commit
c07c2b55
authored
Jun 20, 2023
by
danyao12
Browse files
continue fwd cleanup
parent
dad06b35
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
45 additions
and
45 deletions
+45
-45
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
...ion/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
+45
-45
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
View file @
c07c2b55
...
@@ -882,19 +882,19 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -882,19 +882,19 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
n4
));
// registerNum
n4
));
// registerNum
constexpr
auto
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4
=
// for blockwise copy
constexpr
auto
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4
=
// for blockwise copy
make_naive_tensor_descriptor_packed
(
make_tuple
(
m0
,
//
make_naive_tensor_descriptor_packed
(
make_tuple
(
m0
,
//
MRepeat
n0
,
//
n0
,
//
NRepeat
m1
,
//
m1
,
//
MWaveId
n1
,
//
n1
,
//
NWaveId
m2
,
//
m0 1
m2
,
//
MPerXdl
n2
,
//
n0 4
n2
,
//
NGroupNum
n3
,
//
n1 1
n3
,
//
NInputNum
n4
,
//
m1 4
n4
,
//
registerNum
I1
));
//
n2
1
I1
));
//
I
1
constexpr
auto
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
constexpr
auto
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
I1
,
// NBlockI
D
I1
,
// NBlockI
d
m0
,
// MRepeat
m0
,
// MRepeat
n0
,
// NRepeat
n0
,
// NRepeat
m1
,
// MWaveId
m1
,
// MWaveId
...
@@ -907,26 +907,26 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -907,26 +907,26 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
constexpr
auto
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
=
constexpr
auto
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
=
blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
();
blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
();
constexpr
auto
z
M0
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I0
);
constexpr
auto
Z
M0
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I0
);
constexpr
auto
z
N0
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I1
);
constexpr
auto
Z
N0
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I1
);
constexpr
auto
z
M1
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I2
);
constexpr
auto
Z
M1
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I2
);
constexpr
auto
z
N1
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I3
);
constexpr
auto
Z
N1
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I3
);
constexpr
auto
z
M2
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I4
);
constexpr
auto
Z
M2
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I4
);
constexpr
auto
z
N2
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I5
);
constexpr
auto
Z
N2
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I5
);
constexpr
auto
z
N3
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I6
);
constexpr
auto
Z
N3
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I6
);
constexpr
auto
z
N4
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I7
);
constexpr
auto
Z
N4
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I7
);
constexpr
auto
z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4
=
constexpr
auto
z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
make_tuple
(
make_pass_through_transform
(
z
M0
),
make_tuple
(
make_pass_through_transform
(
Z
M0
),
make_pass_through_transform
(
z
N0
),
make_pass_through_transform
(
Z
N0
),
make_pass_through_transform
(
z
M1
),
make_pass_through_transform
(
Z
M1
),
make_pass_through_transform
(
z
N1
),
make_pass_through_transform
(
Z
N1
),
make_unmerge_transform
(
make_tuple
(
Number
<
zM2
.
value
/
z
N4
.
value
>
{}
,
z
N4
)),
make_unmerge_transform
(
make_tuple
(
ZM2
/
Z
N4
,
Z
N4
)),
make_pass_through_transform
(
z
N2
),
make_pass_through_transform
(
Z
N2
),
make_pass_through_transform
(
z
N3
),
make_pass_through_transform
(
Z
N3
),
make_pass_through_transform
(
z
N4
)),
make_pass_through_transform
(
Z
N4
)),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
2
>
{},
...
@@ -955,7 +955,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -955,7 +955,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
());
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
());
auto
z_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
z_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
ZDataType
*>
(
p_shared
),
static_cast
<
ushort
*>
(
p_shared
),
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetElementSpaceSize
());
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetElementSpaceSize
());
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_id
=
GetGemm0WaveIdx
();
...
@@ -974,7 +974,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -974,7 +974,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
m2
,
// MPerXdl
m2
,
// MPerXdl
n2
,
// NGroupNum
n2
,
// NGroupNum
n3
,
// NInputNum
n3
,
// NInputNum
n4
>
,
n4
>
,
// registerNum
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
7
,
// DstVectorDim
7
,
// DstVectorDim
1
,
// DstScalarPerVector
1
,
// DstScalarPerVector
...
@@ -1007,11 +1007,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -1007,11 +1007,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
0
,
// nrepeat
0
,
// nrepeat
wave_id
[
I0
],
// MWaveId
wave_id
[
I0
],
// MWaveId
wave_id
[
I1
],
// NWaveId
wave_id
[
I1
],
// NWaveId
int
(
wave_m_n_id
[
I1
]
/
4
),
// MPerXdl
wave_m_n_id
[
I1
]
/
ZN4
,
0
,
// group
0
,
wave_m_n_id
[
I0
],
// NInputIndex
wave_m_n_id
[
I0
],
0
,
0
,
wave_m_n_id
[
I1
]
%
4
)};
wave_m_n_id
[
I1
]
%
ZN
4
)};
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ushort
,
...
@@ -1091,8 +1091,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -1091,8 +1091,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
false
>
;
// SnakeCurved
false
>
;
// SnakeCurved
constexpr
auto
block_idx_to_m_n_adaptor
=
make_single_stage_tensor_adaptor
(
constexpr
auto
block_idx_to_m_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
z
M0
,
z
M1
,
z
M2
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Z
M0
,
Z
M1
,
Z
M2
)),
make_unmerge_transform
(
make_tuple
(
z
N0
,
z
N1
,
z
N2
,
z
N3
,
z
N4
))),
make_unmerge_transform
(
make_tuple
(
Z
N0
,
Z
N1
,
Z
N2
,
Z
N3
,
Z
N4
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
,
6
,
7
>
{}));
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
,
6
,
7
>
{}));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment