Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
48a16339
Commit
48a16339
authored
Aug 23, 2023
by
letaoqin
Browse files
load D0 data only for 4 xdl
parent
ab0d58b2
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
74 additions
and
69 deletions
+74
-69
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+74
-69
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
48a16339
...
...
@@ -1155,8 +1155,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// D0
static
constexpr
auto
D0M2
=
Number
<
4
>
{};
static
constexpr
auto
D0M1
=
Number
<
MPer
Block
>
{}
/
D0M2
;
//
static constexpr auto D0M = Number<MPerBlock>{} /
D0M2
;
static
constexpr
auto
D0M1
=
Number
<
MPer
Xdl
>
{}
/
D0M2
;
static
constexpr
auto
D0M
0
=
Number
<
MPerBlock
>
{}
/
Number
<
MPerXdl
>
{}
;
__host__
__device__
static
constexpr
auto
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
const
D0GridDesc_M_N
&
d0_grid_desc_m_n
)
...
...
@@ -1169,10 +1169,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
auto
d0_grid_desc_m0_n0_m1_m2_n1_m3
=
transform_tensor_descriptor
(
d0_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
D0M1
,
D0M2
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
D0M0
,
D0M1
,
D0M2
)),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
NPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
>
{}));
make_tuple
(
Sequence
<
0
,
2
,
3
,
5
>
{},
Sequence
<
1
,
4
>
{}));
return
d0_grid_desc_m0_n0_m1_m2_n1_m3
;
}
...
...
@@ -1188,8 +1188,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
__host__
__device__
static
constexpr
auto
GetD0BlockDescriptor_M0_N0_M1_M2_N1_M3
()
{
// B1 matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
I1
,
I1
,
D0M1
,
Number
<
NPerBlock
>
{},
D0M2
),
return
make_naive_tensor_descriptor
(
make_tuple
(
I1
,
I1
,
I1
,
D0M1
,
Number
<
NPerBlock
>
{},
D0M2
),
make_tuple
(
Number
<
NPerBlock
>
{}
*
D0M2
,
Number
<
NPerBlock
>
{}
*
D0M2
,
Number
<
NPerBlock
>
{}
*
D0M2
,
Number
<
NPerBlock
>
{}
*
D0M2
,
D0M2
,
...
...
@@ -1216,24 +1218,24 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3
();
static
constexpr
auto
d0_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I
8
,
I1
,
D0M2
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I
4
,
I1
,
D0M2
));
using
D0BlockwiseCopy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
I1
,
I1
,
D0M1
,
NPerBlock
,
D0M2
>
,
// BlockSliceLengths
Sequence
<
1
,
1
,
8
,
32
,
1
>
,
// ThreadClusterLengths
Sequence
<
0
,
1
,
2
,
4
,
3
>
,
// ThreadClusterArrangeOrder
Sequence
<
I1
,
I1
,
I1
,
D0M1
,
NPerBlock
,
D0M2
>
,
// BlockSliceLengths
Sequence
<
1
,
1
,
1
,
8
,
32
,
1
>
,
// ThreadClusterLengths
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// ThreadClusterArrangeOrder
D0DataType
,
// SrcData
D0DataType
,
// DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
// SrcDesc
decltype
(
d0_block_desc_m0_n0_m1_m2_n1_m3
),
// DstDesc
Sequence
<
0
,
1
,
2
,
4
,
3
>
,
// SrcDimAccessOrder
Sequence
<
0
,
1
,
3
,
2
,
4
>
,
// DstDimAccessOrder
3
,
// SrcVectorDim
4
,
// DstVectorDim
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// SrcDimAccessOrder
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
// DstDimAccessOrder
4
,
// SrcVectorDim
5
,
// DstVectorDim
D0BlockTransferSrcScalarPerVector
,
// SrcScalarPerVector
4
,
// DstScalarPerVector
1
,
...
...
@@ -1247,7 +1249,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
D0DataType
,
// DstData
decltype
(
d0_block_desc_n0_n1_m0_m1_m2_m3
),
// SrcDesc
decltype
(
d0_thread_desc_
),
// DstDesc
Sequence
<
1
,
1
,
8
,
1
,
4
>
,
// SliceLengths
Sequence
<
1
,
1
,
4
,
1
,
4
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
// DimAccessOrder
4
,
// SrcVectorDim
2
,
// SrcScalarPerVector
...
...
@@ -1833,10 +1835,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// D0
auto
d0_block_copy_global_to_lds
=
typename
D0Loader
::
D0BlockwiseCopy
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
gemm0_m_block_outer_index
,
block_work_idx_n
,
0
,
0
,
0
),
make_multi_index
(
gemm0_m_block_outer_index
,
block_work_idx_n
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{},
D0Loader
::
d0_block_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Loader
::
D0ThreadCopy
(
...
...
@@ -2002,8 +2004,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// add bias
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
//
static constexpr auto c_thread_desc_ =
//
make_naive_tensor_descriptor_packed(make_tuple(
I2
, Number<16>{}));
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
D0M0
,
Number
<
16
>
{}));
const
auto
d0_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_d0_grid
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
...
...
@@ -2016,11 +2018,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
D0Loader
::
d0_thread_desc_
.
GetElementSpaceSize
());
ignore
=
d0_thread_buf
;
static_for
<
0
,
D0M0
,
1
>
{}([
&
](
auto
mr
)
{
// load data to lds
d0_block_copy_global_to_lds
.
RunRead
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0_grid_buf
);
d0_block_copy_global_to_lds
.
RunRead
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0_grid_buf
);
//
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
//
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(
1
, 0, 0, 0, 0));
d0_block_copy_global_to_lds
.
MoveSrcSliceWindow
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
1
,
0
,
0
,
0
));
d0_block_copy_global_to_lds
.
RunWrite
(
D0Loader
::
d0_block_desc_m0_n0_m1_m2_n1_m3
,
d0_block_buf
);
...
...
@@ -2032,11 +2036,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
D0Loader
::
d0_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
d0_thread_buf
);
// static_for<0, 2, 1>{}([&](auto mr) {
// bias add
static_for
<
0
,
s_slash_p
_thread_buf
.
Size
(),
1
>
{}([
&
](
auto
i
)
{
//
constexpr index_t c_offset =
//
c_thread_desc_.CalculateOffset(make_tuple(mr, i));
static_for
<
0
,
d0
_thread_buf
.
Size
(),
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
mr
,
i
));
// if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
// if(ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]) != 1.0f)
...
...
@@ -2054,12 +2058,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]));
// }
s_slash_p_thread_buf
(
i
)
+=
ck
::
type_convert
<
FloatGemmAcc
>
(
d0_thread_buf
[
i
]);
s_slash_p_thread_buf
(
Number
<
c_offset
>
{})
+=
ck
::
type_convert
<
FloatGemmAcc
>
(
d0_thread_buf
[
i
]);
});
});
//});
d0_block_copy_global_to_lds
.
MoveSrcSliceWindow
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
-
1
,
0
,
0
,
0
,
0
));
d0_block_copy_global_to_lds
.
MoveSrcSliceWindow
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
-
1
,
0
,
-
D0M0
.
value
,
0
,
0
,
0
));
}
// P_i: = softmax(scalar * S_i:)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment