Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
118742b6
Commit
118742b6
authored
Jun 21, 2023
by
ltqin
Browse files
rewrite code
parent
33fad9ba
Changes
5
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
64 additions
and
137 deletions
+64
-137
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v4.cpp
..._softmax_gemm/batched_multihead_attention_backward_v4.cpp
+16
-16
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v6.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v6.hpp
+6
-6
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v7.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v7.hpp
+7
-7
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_bacckward_ydotygrad.hpp
...dwise_batched_multihead_attention_bacckward_ydotygrad.hpp
+33
-108
include/ck/utility/get_id.hpp
include/ck/utility/get_id.hpp
+2
-0
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v4.cpp
View file @
118742b6
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v6.hpp
View file @
118742b6
...
...
@@ -777,18 +777,18 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
YGridDesc_M_O
,
ORSGridDesc_M
,
BlockSize
,
MPerBlock
,
NPerBlock
,
128
,
128
,
KPerBlock
,
Gemm1NPerBlock
,
Gemm1KPerBlock
,
AK1
,
BK1
,
B1K1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
32
,
32
,
1
,
4
,
ABlockLdsExtraM
,
BBlockLdsExtraN
,
Deterministic
>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v7.hpp
View file @
118742b6
...
...
@@ -793,18 +793,18 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
YGridDesc_M_O
,
ORSGridDesc_M
,
BlockSize
,
MPerBlock
,
NPerBlock
,
256
,
128
,
KPerBlock
,
Gemm1NPerBlock
,
32
,
Gemm1KPerBlock
,
AK1
,
BK1
,
B1K1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
64
,
64
,
1
,
4
,
ABlockLdsExtraM
,
BBlockLdsExtraN
,
Deterministic
>
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_bacckward_ydotygrad.hpp
View file @
118742b6
...
...
@@ -127,14 +127,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
{
const
index_t
M
=
lse_grid_desc_m
.
GetLength
(
I0
);
const
index_t
MBlock
=
M
/
MPerBlock
;
constexpr
index_t
MWave
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
const
auto
lse_grid_desc_mblock_mrepeat_mwave_mperxdl
=
transform_tensor_descriptor
(
lse_grid_desc_m
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MXdlPerWave
>
{},
MWave
,
Number
<
MPerXdl
>
{}))),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}));
make_tuple
(
Sequence
<
0
,
1
>
{}));
return
lse_grid_desc_mblock_mrepeat_mwave_mperxdl
;
}
...
...
@@ -214,13 +212,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
template
<
index_t
BlockSize_
,
index_t
BlockSliceLength_M_
,
index_t
BlockSliceLength_O_
>
struct
YDotYGrad_M_O_
{
static_assert
(
BlockSize_
==
BlockSliceLength_M_
);
static
constexpr
auto
ThreadSliceLength_M
=
Number
<
1
>
{};
static
constexpr
index_t
SrcScalarPerVector
=
16
/
sizeof
(
InputDataType
);
static
constexpr
auto
ThreadClusterLength_O
=
Number
<
BlockSliceLength_O_
/
SrcScalarPerVector
>
{};
static
constexpr
auto
ThreadClusterLength_O
=
Number
<
1
>
{};
static
constexpr
auto
ThreadClusterLength_M
=
Number
<
BlockSize_
/
ThreadClusterLength_O
>
{};
static
constexpr
auto
ThreadSliceLength_O
=
Number
<
SrcScalarPerVector
>
{};
static
constexpr
auto
ThreadSliceLength_M
=
Number
<
BlockSliceLength_M_
*
ThreadClusterLength_O
/
BlockSize_
>
{};
static
constexpr
auto
ThreadSliceLength_O
=
Number
<
BlockSliceLength_O_
>
{};
static_assert
(
ThreadClusterLength_O
*
ThreadSliceLength_O
==
BlockSliceLength_O_
,
""
);
static_assert
(
ThreadClusterLength_M
*
ThreadSliceLength_M
==
BlockSliceLength_M_
,
""
);
...
...
@@ -274,32 +271,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
const
index_t
block_work_idx_m
=
Deterministic
?
block_idx_m
:
block_work_idx
[
I0
];
//
// set up Y dot dY
//
// S: blockwise gemm
auto
s_blockwise_gemm
=
typename
Gemm0
::
BlockwiseGemm
{};
// TransposeC
auto
acc0_thread_origin
=
s_blockwise_gemm
.
CalculateCThreadOriginDataIndex8D
(
Number
<
0
>
{},
Number
<
0
>
{},
Number
<
0
>
{},
Number
<
0
>
{});
constexpr
auto
thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
=
s_blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
();
constexpr
auto
m0
=
thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I0
);
constexpr
auto
m1
=
thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I2
);
constexpr
auto
m2
=
thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I4
);
constexpr
auto
ors_thread_desc_mblock_mrepeat_mwave_mperxdl
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
m0
,
m1
,
m2
));
// m0, n0 are m/n repeat per wave
// m1, n1 are number of waves
constexpr
auto
p_block_lengths
=
s_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
();
constexpr
auto
P_M0
=
p_block_lengths
[
I0
];
// repeats
constexpr
auto
P_M1
=
p_block_lengths
[
I2
];
// waves
constexpr
auto
P_M2
=
p_block_lengths
[
I4
];
// xdl
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
));
constexpr
auto
y_thread_desc_m0_m1_o0_o1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
YDotYGrad_M_O
::
ThreadSliceLength_M
,
I1
,
YDotYGrad_M_O
::
ThreadSliceLength_O
));
constexpr
auto
y_thread_cluster_desc
=
make_cluster_descriptor
(
Sequence
<
I1
,
YDotYGrad_M_O
::
ThreadClusterLength_M
,
...
...
@@ -311,6 +288,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
const
auto
y_thread_data_on_block_idx
=
y_thread_cluster_idx
*
y_thread_desc_m0_m1_o0_o1
.
GetLengths
();
// if(get_thread_global_1d_id() == 1)
// {
// printf("y_thread_data_on_block_idx:{ %d, %d, %d,%d}, get_thread_local_1d_id: %d\n",
// y_thread_data_on_block_idx[I0],
// y_thread_data_on_block_idx[I1],
// y_thread_data_on_block_idx[I2],
// y_thread_data_on_block_idx[I3],
// get_thread_local_1d_id());
// }
const
auto
y_thread_data_on_grid_idx
=
make_multi_index
(
block_work_idx_m
,
I0
,
I0
/* all WGs start from o_block_idx = 0 */
,
I0
)
+
...
...
@@ -337,42 +324,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
auto
y_dot_ygrad_block_accum_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatGemmAcc
*>
(
p_shared
),
MPerBlock
);
constexpr
auto
y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl
=
make_naive_tensor_descriptor
(
make_tuple
(
I1
,
P_M0
,
P_M1
,
P_M2
),
make_tuple
(
P_M0
*
P_M1
*
P_M2
,
P_M1
*
P_M2
,
P_M2
,
I1
));
constexpr
auto
y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl
=
ors_thread_desc_mblock_mrepeat_mwave_mperxdl
;
// reuse LSE thread descriptor because
// per-thread LSE data and y_dot_ygrad is
// tiled the same way
auto
y_dot_ygrad_thread_copy_lds_to_vgpr
=
ThreadwiseTensorSliceTransfer_v2
<
FloatGemmAcc
,
FloatGemmAcc
,
decltype
(
y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl
),
decltype
(
y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl
),
Sequence
<
1
,
m0
,
m1
,
m2
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
m2
,
1
,
false
>
{
y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl
,
make_multi_index
(
I0
,
// mblock
acc0_thread_origin
[
I0
],
// mrepeat
acc0_thread_origin
[
I2
],
// mwave
acc0_thread_origin
[
I4
])};
// mperxdl
auto
y_dot_ygrad_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatGemmAcc
>
(
y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl
.
GetElementSpaceSize
());
if
constexpr
(
Deterministic
)
{
block_sync_lds
();
}
//
// calculate Y dot dY
//
// clear accum buffers
y_dot_ygrad_thread_accum_buf
.
Clear
();
y_dot_ygrad_block_accum_buf
.
Clear
();
...
...
@@ -406,23 +362,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
oblock_idx
++
;
}
while
(
oblock_idx
<
y_grid_desc_mblock_mperblock_oblock_operblock
.
GetLength
(
I2
));
// blockwise reduction using atomic_add
block_sync_lds
();
static_for
<
0
,
YDotYGrad_M_O
::
ThreadSliceLength_M
,
1
>
{}([
&
](
auto
iM
)
{
const
auto
idx_on_block
=
y_thread_data_on_block_idx
[
I1
]
+
iM
;
y_dot_ygrad_block_accum_buf
.
AtomicAdd
(
idx_on_block
,
true
,
y_dot_ygrad_thread_accum_buf
[
iM
]);
});
block_sync_lds
();
// distribute y_dot_ygrad to threads; LDS accum buffer can be safely reused after barrier
y_dot_ygrad_thread_copy_lds_to_vgpr
.
Run
(
y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl
,
y_dot_ygrad_block_accum_buf
,
y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
y_dot_ygrad_thread_buf
);
auto
ors_grid_desc_mblock_mrepeat_mwave_mperxdl
=
MakeORSGridDescriptor_MBlock_MRepeat_NWave_MPerXdl
(
ors_grid_desc_m
);
...
...
@@ -432,40 +371,26 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
decltype
(
ors_thread_desc_mblock_mrepeat_mwave_mperxdl
),
decltype
(
ors_grid_desc_mblock_mrepeat_mwave_mperxdl
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
1
,
1
,
1
,
1
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
Sequence
<
1
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
false
>
{
ors_grid_desc_mblock_mrepeat_mwave_mperxdl
,
make_multi_index
(
block_work_idx_m
,
// mblock
0
,
// mrepeat
acc0_thread_origin
[
I2
],
// mwave
acc0_thread_origin
[
I4
]),
// mperxdl
get_thread_local_1d_id
()),
// mperxdl
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}};
if
(
get_warp_local_1d_id
()
<
32
)
{
static_for
<
0
,
MXdlPerWave
,
1
>
{}([
&
](
auto
I
)
{
// copy from VGPR to Global
ors_thread_copy_vgpr_to_global
.
Run
(
ors_thread_desc_mblock_mrepeat_mwave_mperxdl
,
make_tuple
(
I0
,
Number
<
I
>
{},
I0
,
I0
),
y_dot_ygrad_thread_buf
,
make_tuple
(
I0
,
I0
),
y_dot_ygrad_thread_
accum_
buf
,
ors_grid_desc_mblock_mrepeat_mwave_mperxdl
,
ors_grid_buf
);
ors_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
ors_grid_desc_mblock_mrepeat_mwave_mperxdl
,
make_multi_index
(
0
,
1
,
0
,
0
));
});
}
// if(get_warp_local_1d_id() == 0)
// {
// printf(
// "acc0_thread_origin[I0]:%d acc0_thread_origin[I2]: %d acc0_thread_origin[I4]:
// %d\t", acc0_thread_origin[I0], acc0_thread_origin[I2], acc0_thread_origin[I4]);
// }
ignore
=
ors_thread_copy_vgpr_to_global
;
ignore
=
ors_grid_desc_mblock_mrepeat_mwave_mperxdl
;
}
};
...
...
include/ck/utility/get_id.hpp
View file @
118742b6
...
...
@@ -25,4 +25,6 @@ __device__ index_t get_grid_size() { return gridDim.x; }
__device__
index_t
get_block_size
()
{
return
blockDim
.
x
;
}
__device__
index_t
get_thread_local_1d_id_in_warp
()
{
return
threadIdx
.
x
%
get_warp_size
();
}
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment