Commit fced127d
authored Jan 30, 2023 by danyao12

only read dO once, reduce data reading from HBM

parent 84f162f9
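The change is a standard staging pattern: instead of re-reading the dO (ygrad) tile from HBM on every pass over the O dimension, the tile is copied into LDS once and re-read from there. A minimal HIP-style sketch of that pattern, with hypothetical names and a simplified one-thread-per-element mapping (this is illustrative only, not CK's actual API):

    #include <hip/hip_runtime.h>

    // Sketch only: stage a dO tile in LDS once, then reuse it for the
    // row-wise reduction sum_o(y[m][o] * dy[m][o]) that this kernel needs.
    __global__ void y_dot_dy_staged(const float* y, const float* dy, float* out, int M, int O)
    {
        extern __shared__ float dy_lds[]; // M * O floats, sized at launch

        // one trip to HBM for the whole dO tile
        for(int i = threadIdx.x; i < M * O; i += blockDim.x)
            dy_lds[i] = dy[i];
        __syncthreads(); // LDS writes must be visible before any thread reads

        // reduction loop: y still streams from HBM, dO now comes from LDS
        for(int m = threadIdx.x; m < M; m += blockDim.x)
        {
            float acc = 0.f;
            for(int o = 0; o < O; ++o)
                acc += y[m * O + o] * dy_lds[m * O + o];
            out[m] = acc;
        }
    }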
Showing 1 changed file with 73 additions and 70 deletions (+73, -70)
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp (file mode 100644 → 100755)
@@ -309,13 +309,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
+        static constexpr auto ygrad_block_desc_k0_m_k1 =
+            GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
-        // A matrix in LDS memory, dst of blockwise copy
-        static constexpr auto a_block_desc_ak0_m_ak1 =
-            GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
+        // // A matrix in LDS memory, dst of blockwise copy
+        // static constexpr auto a_block_desc_ak0_m_ak1 =
+        //     GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
-        // B matrix in LDS memory, dst of blockwise copy
-        static constexpr auto b_block_desc_bk0_n_bk1 =
-            GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
+        // // B matrix in LDS memory, dst of blockwise copy
+        // static constexpr auto b_block_desc_bk0_n_bk1 =
+        //     GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

         template <typename GridDesc_K0_M_K1>
         using QBlockwiseCopy =
@@ -1002,30 +1002,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
             Number<BlockSliceLength_M_ * ThreadClusterLength_O / BlockSize_>{};

         // dY matrix in LDS memory, dst of blockwise copy
-        static constexpr auto ygrad_block_desc_k0_m_k1 =
+        static constexpr auto ygrad_block_desc_o0_m_o1 =
             GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

-        __host__ __device__ static constexpr auto MakeYGradBlockDesc_M0_K0_M1_K1()
+        __host__ __device__ static constexpr auto MakeYGradBlockDesc_M_O()
         {
-            const auto K0_ = ygrad_block_desc_k0_m_k1.GetLength(I0);
-            const auto M_  = ygrad_block_desc_k0_m_k1.GetLength(I1);
-            const auto K1_ = ygrad_block_desc_k0_m_k1.GetLength(I2);
+            const auto O0_ = ygrad_block_desc_o0_m_o1.GetLength(I0);
+            const auto M_  = ygrad_block_desc_o0_m_o1.GetLength(I1);
+            const auto O1_ = ygrad_block_desc_o0_m_o1.GetLength(I2);
+            static_assert(O0_ * O1_ == BlockSliceLength_O_, "");
+            static_assert(M_ == BlockSliceLength_M_, "");

-            constexpr auto ygrad_block_desc_k_m = transform_tensor_descriptor( //(64, 128)
-                ygrad_block_desc_k0_m_k1,
-                make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0_, K1_)), //(8, 8)
+            return transform_tensor_descriptor( //(128, 64)
+                ygrad_block_desc_o0_m_o1,
+                make_tuple(make_merge_transform_v3_division_mod(make_tuple(O0_, O1_)), //(8, 8)
                            make_pass_through_transform(M_)), //128
                 make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
+                make_tuple(Sequence<1>{}, Sequence<0>{}));
-
-            return transform_tensor_descriptor( //(32, 8, 4, 8)
-                ygrad_block_desc_k_m,
-                make_tuple(
-                    make_unmerge_transform(make_tuple(ThreadClusterLength_O, ThreadSliceLength_O)), //(8, 8)
-                    make_unmerge_transform(make_tuple(ThreadClusterLength_M, ThreadSliceLength_M))), //(32, 4)
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}));
         }
+        static constexpr auto ygrad_block_desc_m_o = MakeYGradBlockDesc_M_O();

         static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
         static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
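For reference, the merge transform above folds the packed (O0, M, O1) LDS layout into a 2-D (M, O) view by computing o = o0 * O1 + o1, while the pass-through leaves M untouched; division_mod recovers (o0, o1) from o. A small standalone sketch of that index arithmetic in plain C++ (not CK's descriptor machinery; the lengths are the //(8, 8) and //128 values noted in the diff comments):

    #include <cassert>

    int main()
    {
        constexpr int O0 = 8, O1 = 8, M = 128; // example lengths from the diff comments

        // linear offset of (o0, m, o1) in the packed (O0, M, O1) block
        auto block_offset = [](int o0, int m, int o1) { return (o0 * M + m) * O1 + o1; };

        // the merged view addresses the same element as (m, o) with o = o0 * O1 + o1
        int o0 = 3, m = 17, o1 = 5;
        int o  = o0 * O1 + o1;

        // division_mod recovers (o0, o1) from o, so both views agree
        assert(o / O1 == o0 && o % O1 == o1);
        assert(block_offset(o / O1, m, o % O1) == block_offset(o0, m, o1));
        return 0;
    }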
@@ -1127,8 +1126,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
         static constexpr index_t reduction_space_size_aligned =
             math::integer_least_multiple(BlockSize, max_lds_align);

-        static constexpr auto reduction_space_offset = 0;
-        // static constexpr auto reduction_space_offset = ygrad_block_space_size_aligned.value + q_block_space_size_aligned.value;
+        static constexpr auto reduction_space_offset =
+            (ygrad_block_space_size_aligned.value + q_block_space_size_aligned.value) * sizeof(DataType) / sizeof(FloatGemmAcc);

         // LDS allocation for C shuffle in LDS
         static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
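The sizeof scaling matters because p_shared is indexed in FloatGemmAcc elements once it is cast, while the ygrad/q block sizes are counted in DataType elements. A worked example under the assumption that DataType is a 2-byte half type (stood in for by short below) and FloatGemmAcc is float:

    #include <cstdio>

    int main()
    {
        // hypothetical block sizes, counted in DataType (2-byte) elements
        constexpr int ygrad_block_elems = 128 * 64;
        constexpr int q_block_elems     = 128 * 64;

        // same storage re-expressed in FloatGemmAcc (4-byte) elements,
        // mirroring (a + b) * sizeof(DataType) / sizeof(FloatGemmAcc)
        constexpr int offset_floats =
            (ygrad_block_elems + q_block_elems) * sizeof(short) / sizeof(float);

        std::printf("%d half elements -> reduction starts at float index %d\n",
                    ygrad_block_elems + q_block_elems, offset_floats);
        return 0;
    }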
@@ -1543,6 +1541,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
         constexpr auto y_thread_desc_m0_m1_o0_o1 = make_naive_tensor_descriptor_packed(
             make_tuple(I1, YDotYGrad_M_O::ThreadSliceLength_M, I1, YDotYGrad_M_O::ThreadSliceLength_O));
+        constexpr auto ygrad_thread_desc_m_o = make_naive_tensor_descriptor_packed(
+            make_tuple(YDotYGrad_M_O::ThreadSliceLength_M, YDotYGrad_M_O::ThreadSliceLength_O));

         constexpr auto y_thread_cluster_desc = make_cluster_descriptor(
             Sequence<I1, YDotYGrad_M_O::ThreadClusterLength_M,
@@ -1552,15 +1552,23 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
         const auto y_thread_cluster_idx =
             y_thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
+        constexpr auto ygrad_thread_cluster_desc = make_cluster_descriptor(
+            Sequence<YDotYGrad_M_O::ThreadClusterLength_M, YDotYGrad_M_O::ThreadClusterLength_O>{},
+            Sequence<0, 1>{});
+        const auto ygrad_thread_cluster_idx =
+            ygrad_thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));

         const auto y_thread_data_on_block_idx =
             y_thread_cluster_idx * y_thread_desc_m0_m1_o0_o1.GetLengths();
+        const auto ygrad_thread_data_on_block_idx =
+            ygrad_thread_cluster_idx * ygrad_thread_desc_m_o.GetLengths();

         const auto y_thread_data_on_grid_idx =
             make_multi_index(
                 block_work_idx[I0], I0, I0 /* all WGs start from o_block_idx = 0 */, I0) +
             y_thread_data_on_block_idx;

-        // performs double duty for both y and ygrad
-        auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
+        // performs for y
+        auto y_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
             DataType,
             DataType,
             YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
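The new ygrad cluster descriptor assigns each flat thread id a (m, o) position in the block, row-major over (ThreadClusterLength_M, ThreadClusterLength_O) per the order Sequence<0, 1>. The equivalent arithmetic as a sketch, with assumed cluster lengths of (32, 8):

    #include <cassert>

    int main()
    {
        constexpr int cluster_m = 32, cluster_o = 8; // assumed thread-cluster lengths

        for(int tid = 0; tid < cluster_m * cluster_o; ++tid)
        {
            // row-major decomposition, matching order Sequence<0, 1>
            int m_idx = tid / cluster_o;
            int o_idx = tid % cluster_o;

            // each thread then owns a ThreadSliceLength_M x ThreadSliceLength_O
            // patch whose origin scales with (m_idx, o_idx)
            assert(m_idx < cluster_m && o_idx < cluster_o);
        }
        return 0;
    }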
@@ -1574,26 +1582,26 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
             true /* InvalidElementAsNaN */>(
             y_grid_desc_mblock_mperblock_oblock_operblock, y_thread_data_on_grid_idx);

-        // performs for ygrad
-        // auto ygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
-        //     DataType,
-        //     DataType,
-        //     YBlockDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
-        //     decltype(y_thread_desc_m0_m1_o0_o1),
-        //     decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
-        //     Sequence<0, 1, 2, 3>,
-        //     3,                                 // SrcVectorDim
-        //     YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
-        //     1,                                 // SrcScalarStrideInVector
-        //     true /* ResetCoordAfterRun */,
-        //     true /* InvalidElementAsNaN */>(
-        //     y_block_desc_mblock_mperblock_oblock_operblock,
-        //     y_thread_data_on_block_idx);
+        // performs for ygrad
+        auto ygrad_threadwise_copy =
+            ThreadwiseTensorSliceTransfer_v2<DataType,
+                                             DataType,
+                                             decltype(YDotYGrad_M_O::ygrad_block_desc_m_o),
+                                             decltype(ygrad_thread_desc_m_o),
+                                             decltype(ygrad_thread_desc_m_o.GetLengths()),
+                                             Sequence<0, 1>,
+                                             1,                                 // SrcVectorDim
+                                             YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
+                                             1,                                 // SrcScalarStrideInVector
+                                             true /* ResetCoordAfterRun */,
+                                             true /* InvalidElementAsNaN */>(
+                YDotYGrad_M_O::ygrad_block_desc_m_o, ygrad_thread_data_on_block_idx);

         auto y_thread_buf     = typename YDotYGrad_M_O::SrcBufType{};
         auto ygrad_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
         auto y_dot_ygrad_thread_accum_buf = typename YDotYGrad_M_O::DstBufType{};

         auto y_dot_ygrad_block_accum_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatGemmAcc*>(p_shared), MPerBlock);
+            static_cast<FloatGemmAcc*>(p_shared) + SharedMemTrait::reduction_space_offset,
+            MPerBlock);

         constexpr auto y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl =
             make_naive_tensor_descriptor(make_tuple(I1, P_M0, P_M1, P_M2),
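Placing the accumulation buffer at p_shared + reduction_space_offset is the usual way of carving several logical buffers out of one dynamically sized LDS allocation. A sketch of that layout with hypothetical sizes (float throughout to keep the offset arithmetic trivial; the real code converts between DataType and FloatGemmAcc element counts as shown earlier):

    #include <hip/hip_runtime.h>

    // Sketch: several logical buffers carved out of one dynamic LDS allocation,
    // analogous to p_shared + SharedMemTrait::reduction_space_offset above.
    __global__ void carve_lds_sketch()
    {
        extern __shared__ char p_shared[];

        constexpr int ygrad_elems = 128 * 64; // dO block, hypothetical size
        constexpr int q_elems     = 128 * 64; // Q block, hypothetical size

        float* ygrad_block = reinterpret_cast<float*>(p_shared);
        float* q_block     = ygrad_block + ygrad_elems;
        float* reduction   = q_block + q_elems; // starts past both blocks

        if(threadIdx.x == 0)
            reduction[0] = 0.f; // the row-wise accumulator lives here
    }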
@@ -1622,6 +1630,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
         auto y_dot_ygrad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemmAcc>(
             y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl.GetElementSpaceSize());

+        // load ygrad
+        gemm_tile_ygrad_blockwise_copy.Run(
+            ygrad_grid_desc_o0_m_o1, ygrad_grid_buf,
+            GemmBlockwiseCopy::ygrad_block_desc_k0_m_k1, ygrad_block_buf, I0);
+        block_sync_lds();

         //
         // calculate Y dot dY
         //
@@ -1630,34 +1643,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
         y_dot_ygrad_thread_accum_buf.Clear();
         y_dot_ygrad_block_accum_buf.Clear();

         index_t oblock_idx = 0;
         do
         {
-            yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
-                                       y_grid_buf,
-                                       y_thread_desc_m0_m1_o0_o1,
-                                       make_tuple(I0, I0, I0, I0),
-                                       y_thread_buf);
-            yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
-                                       ygrad_grid_buf,
-                                       y_thread_desc_m0_m1_o0_o1,
-                                       make_tuple(I0, I0, I0, I0),
-                                       ygrad_thread_buf);
-
-            static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
-                static_for<0, YDotYGrad_M_O::ThreadSliceLength_O, 1>{}([&](auto iO) {
-                    constexpr auto offset = y_thread_desc_m0_m1_o0_o1.CalculateOffset(
-                        make_multi_index(I0, iM, I0, iO));
-                    y_dot_ygrad_thread_accum_buf(iM) +=
-                        y_thread_buf[Number<offset>{}] * ygrad_thread_buf[Number<offset>{}];
-                });
-            });
-
-            yygrad_threadwise_copy.MoveSrcSliceWindow(y_grid_desc_mblock_mperblock_oblock_operblock,
-                                                      make_multi_index(0, 0, 1, 0));
+            y_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
+                                  y_grid_buf,
+                                  y_thread_desc_m0_m1_o0_o1,
+                                  make_tuple(I0, I0, I0, I0),
+                                  y_thread_buf);
+            ygrad_threadwise_copy.Run(YDotYGrad_M_O::ygrad_block_desc_m_o,
+                                      ygrad_block_buf,
+                                      ygrad_thread_desc_m_o,
+                                      make_tuple(I0, I0),
+                                      ygrad_thread_buf);

             oblock_idx++;
         } while(oblock_idx < y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I2));

+        static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
+            static_for<0, YDotYGrad_M_O::ThreadSliceLength_O, 1>{}([&](auto iO) {
+                constexpr auto y_offset = y_thread_desc_m0_m1_o0_o1.CalculateOffset(
+                    make_multi_index(I0, iM, I0, iO));
+                constexpr auto ygrad_offset =
+                    ygrad_thread_desc_m_o.CalculateOffset(make_multi_index(iM, iO));
+                y_dot_ygrad_thread_accum_buf(iM) += y_thread_buf[Number<y_offset>{}] *
+                                                    ygrad_thread_buf[Number<ygrad_offset>{}];
+            });
+        });

         // blockwise reduction using atomic_add
         block_sync_lds();
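After each thread has formed its per-row partials, the step tagged "blockwise reduction using atomic_add" combines the partials of all threads sharing a row into one value per row of the block (y_dot_ygrad_block_accum_buf above plays this role). A minimal HIP sketch of the pattern, with hypothetical sizes and thread-to-row mapping:

    #include <hip/hip_runtime.h>

    // Sketch: threads that share a row combine their partial dot products
    // with atomicAdd on an LDS accumulator, then the result is read back out.
    __global__ void rowwise_atomic_reduce(const float* partials, float* out, int M)
    {
        __shared__ float row_accum[128]; // one slot per row of the block (M <= 128)

        for(int m = threadIdx.x; m < M; m += blockDim.x)
            row_accum[m] = 0.f;
        __syncthreads();

        // hypothetical mapping: thread t contributes partials[t] to row t % M
        atomicAdd(&row_accum[threadIdx.x % M], partials[threadIdx.x]);
        __syncthreads();

        for(int m = threadIdx.x; m < M; m += blockDim.x)
            out[m] = row_accum[m];
    }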
@@ -1691,9 +1697,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
             // load q
             gemm_tile_q_blockwise_copy.Run(
                 q_grid_desc_k0_m_k1, q_grid_buf,
                 GemmBlockwiseCopy::q_block_desc_k0_m_k1, q_block_buf, I0);
-            // load ygrad
-            gemm_tile_ygrad_blockwise_copy.Run(
-                ygrad_grid_desc_o0_m_o1, ygrad_grid_buf,
-                GemmBlockwiseCopy::ygrad_block_desc_k0_m_k1, ygrad_block_buf, I0);

             // gemm1 K loop
             index_t gemm1_k_block_outer_index = 0;
             do