Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
6b814bac
Commit
6b814bac
authored
Jul 06, 2023
by
danyao12
Browse files
code cleanup for light kernels
parent
73f0c21b
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
4 additions
and
329 deletions
+4
-329
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
+0
-21
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
+0
-21
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
+2
-167
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
+2
-120
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
6b814bac
...
...
@@ -97,7 +97,6 @@ template <typename GridwiseGemm,
typename
BGridDesc_BK0_N_BK1
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
,
typename
B1GridDesc_BK0_N_BK1
,
typename
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
typename
LSEGridDescriptor_M
,
typename
DGridDescriptor_M
,
typename
YGradGridDesc_O0_M_O1
,
...
...
@@ -131,8 +130,6 @@ __global__ void
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
LSEGridDescriptor_M
lse_grid_desc_m
,
const
DGridDescriptor_M
d_grid_desc_m
,
const
YGradGridDesc_O0_M_O1
ygrad_grid_desc_o0_m_o1
,
...
...
@@ -198,7 +195,6 @@ __global__ void
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
d_grid_desc_m
,
ygrad_grid_desc_o0_m_o1
,
...
...
@@ -233,7 +229,6 @@ __global__ void
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
d_grid_desc_m
,
ygrad_grid_desc_o0_m_o1
,
...
...
@@ -258,7 +253,6 @@ __global__ void
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
b1_grid_desc_bk0_n_bk1
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
block_2_ctile_map
;
ignore
=
batch_count
;
ignore
=
compute_base_ptr_of_batch
;
...
...
@@ -842,7 +836,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
c_gs_ms_gemm1ns_strides
)},
z_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
y_grid_desc_mblock_mperblock_oblock_operblock_
{},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
k_grid_desc_n_k_
)},
d_block_2_ctile_map_
{
GridwiseYDotYGrad
::
MakeDefaultBlock2CTileMap
(
y_grid_desc_m_o_
)},
d_y_grid_desc_mblock_mperblock_oblock_operblock_
{},
...
...
@@ -882,16 +875,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
ignore
=
acc1_biases_gs_ms_gemm1ns_lengths
;
ignore
=
acc1_biases_gs_ms_gemm1ns_strides
;
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
b1_grid_desc_bk0_n_bk1_
,
y_grid_desc_m_o_
))
{
y_grid_desc_mblock_mperblock_oblock_operblock_
=
GridwiseGemm
::
MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
(
y_grid_desc_m_o_
);
}
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
...
...
@@ -961,8 +944,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
;
...
...
@@ -1081,7 +1062,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
DeviceOp
::
LSEGridDesc_M
,
DeviceOp
::
DGridDesc_M
,
DeviceOp
::
YGradGridDesc_O0_M_O1
,
...
...
@@ -1116,7 +1096,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg
.
lse_grid_desc_m_
,
arg
.
d_grid_desc_m_
,
arg
.
ygrad_grid_desc_o0_m_o1_
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @
6b814bac
...
...
@@ -96,7 +96,6 @@ template <typename GridwiseGemm,
typename
BGridDesc_BK0_N_BK1
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
,
typename
B1GridDesc_BK0_N_BK1
,
typename
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
typename
LSEGridDescriptor_M
,
typename
DGridDescriptor_M
,
typename
YGradGridDesc_M0_O_M1
,
...
...
@@ -130,8 +129,6 @@ __global__ void
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
LSEGridDescriptor_M
lse_grid_desc_m
,
const
DGridDescriptor_M
d_grid_desc_m
,
const
YGradGridDesc_M0_O_M1
ygrad_grid_desc_m0_o_m1
,
...
...
@@ -197,7 +194,6 @@ __global__ void
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
d_grid_desc_m
,
ygrad_grid_desc_m0_o_m1
,
...
...
@@ -232,7 +228,6 @@ __global__ void
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
d_grid_desc_m
,
ygrad_grid_desc_m0_o_m1
,
...
...
@@ -257,7 +252,6 @@ __global__ void
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
b1_grid_desc_bk0_n_bk1
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
block_2_ctile_map
;
ignore
=
batch_count
;
ignore
=
compute_base_ptr_of_batch
;
...
...
@@ -855,7 +849,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
c_gs_ms_gemm1ns_strides
)},
z_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
y_grid_desc_mblock_mperblock_oblock_operblock_
{},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
k_grid_desc_n_k_
)},
d_block_2_ctile_map_
{
GridwiseYDotYGrad
::
MakeDefaultBlock2CTileMap
(
y_grid_desc_m_o_
)},
d_y_grid_desc_mblock_mperblock_oblock_operblock_
{},
...
...
@@ -895,16 +888,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
ignore
=
acc1_biases_gs_ms_gemm1ns_lengths
;
ignore
=
acc1_biases_gs_ms_gemm1ns_strides
;
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
b1_grid_desc_bk0_n_bk1_
,
y_grid_desc_m_o_
))
{
y_grid_desc_mblock_mperblock_oblock_operblock_
=
GridwiseGemm
::
MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
(
y_grid_desc_m_o_
);
}
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
...
...
@@ -974,8 +957,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
;
...
...
@@ -1098,7 +1079,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
DeviceOp
::
LSEGridDesc_M
,
DeviceOp
::
DGridDesc_M
,
DeviceOp
::
YGradGridDesc_M0_O_M1
,
...
...
@@ -1133,7 +1113,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg
.
lse_grid_desc_m_
,
arg
.
d_grid_desc_m_
,
arg
.
ygrad_grid_desc_m0_o_m1_
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
View file @
6b814bac
...
...
@@ -283,25 +283,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
return
true
;
}
__host__
__device__
static
constexpr
auto
MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
(
const
YGridDesc_M_O
&
y_grid_desc_m_o
)
{
const
auto
M
=
y_grid_desc_m_o
.
GetLength
(
I0
);
const
auto
O
=
y_grid_desc_m_o
.
GetLength
(
I1
);
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
OBlock
=
O
/
Gemm1NPerBlock
;
const
auto
y_grid_desc_mblock_mperblock_oblock_operblock
=
transform_tensor_descriptor
(
y_grid_desc_m_o
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MPerBlock
>
{})),
make_unmerge_transform
(
make_tuple
(
OBlock
,
Number
<
Gemm1NPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
return
y_grid_desc_mblock_mperblock_oblock_operblock
;
}
template
<
typename
SrcBlockwiseGemm
>
__host__
__device__
static
constexpr
auto
MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4
(
const
LSEGridDesc_M
&
lse_grid_desc_m
)
...
...
@@ -400,9 +381,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
k_grid_desc_n_k
);
}
using
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
=
remove_cvref_t
<
decltype
(
MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
(
YGridDesc_M_O
{}))
>
;
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
KGridDesc_N_K
{}))
>
;
...
...
@@ -1107,53 +1085,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
true
>
;
};
template
<
index_t
BlockSize_
,
index_t
BlockSliceLength_M_
,
index_t
BlockSliceLength_O_
>
struct
YDotYGrad_M_O_
{
static
constexpr
index_t
SrcScalarPerVector
=
16
/
sizeof
(
InputDataType
);
static
constexpr
auto
ThreadClusterLength_O
=
Number
<
BlockSliceLength_O_
/
SrcScalarPerVector
>
{};
static
constexpr
auto
ThreadClusterLength_M
=
Number
<
BlockSize_
/
ThreadClusterLength_O
>
{};
static
constexpr
auto
ThreadSliceLength_O
=
Number
<
SrcScalarPerVector
>
{};
static
constexpr
auto
ThreadSliceLength_M
=
Number
<
BlockSliceLength_M_
*
ThreadClusterLength_O
/
BlockSize_
>
{};
// dY matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
ygrad_block_desc_o0_m_o1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
__host__
__device__
static
constexpr
auto
MakeYGradBlockDesc_M_O
()
{
const
auto
O0_
=
ygrad_block_desc_o0_m_o1
.
GetLength
(
I0
);
const
auto
M_
=
ygrad_block_desc_o0_m_o1
.
GetLength
(
I1
);
const
auto
O1_
=
ygrad_block_desc_o0_m_o1
.
GetLength
(
I2
);
static_assert
(
O0_
*
O1_
==
BlockSliceLength_O_
,
""
);
static_assert
(
M_
==
BlockSliceLength_M_
,
""
);
return
transform_tensor_descriptor
(
//(128, 64)
ygrad_block_desc_o0_m_o1
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
O0_
,
O1_
)),
//(8, 8)
make_pass_through_transform
(
M_
)),
// 128
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
}
static
constexpr
auto
ygrad_block_desc_m_o
=
MakeYGradBlockDesc_M_O
();
static_assert
(
ThreadClusterLength_O
*
ThreadSliceLength_O
==
BlockSliceLength_O_
,
""
);
static_assert
(
ThreadClusterLength_M
*
ThreadSliceLength_M
==
BlockSliceLength_M_
,
""
);
using
SrcBufType
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
FloatGemmAcc
,
ThreadSliceLength_M
*
ThreadSliceLength_O
,
true
>
;
using
DstBufType
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
FloatGemmAcc
,
ThreadSliceLength_M
,
true
>
;
};
using
YDotYGrad_M_O
=
YDotYGrad_M_O_
<
BlockSize
,
MPerBlock
,
Gemm1NPerBlock
>
;
struct
SharedMemTrait
{
// // LDS allocation for A and B: be careful of alignment
...
...
@@ -1248,8 +1179,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
const
VGridDesc_O0_N_O1
&
v_grid_desc_o0_n_o1
,
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
&
y_grid_desc_mblock_mperblock_oblock_operblock
,
const
LSEGridDesc_M
&
lse_grid_desc_m
,
const
DGridDesc_M
&
d_grid_desc_m
,
const
YGradGridDesc_O0_M_O1
&
ygrad_grid_desc_o0_m_o1
,
...
...
@@ -1666,101 +1595,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
wave_m_n_id
[
I1
]),
// NPerXdl
tensor_operation
::
element_wise
::
PassThrough
{}};
//
// set up Y dot dY
//
// m0, n0 are m/n repeat per wave
// m1, n1 are number of waves
constexpr
auto
p_block_lengths
=
s_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
().
GetLengths
();
constexpr
auto
P_M0
=
p_block_lengths
[
I0
];
// repeats
constexpr
auto
P_M1
=
p_block_lengths
[
I2
];
// waves
constexpr
auto
P_M2
=
p_block_lengths
[
I4
];
// xdl
constexpr
auto
P_M3
=
p_block_lengths
[
I5
];
constexpr
auto
P_M4
=
p_block_lengths
[
I6
];
constexpr
auto
y_thread_desc_m0_m1_o0_o1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
YDotYGrad_M_O
::
ThreadSliceLength_M
,
I1
,
YDotYGrad_M_O
::
ThreadSliceLength_O
));
constexpr
auto
ygrad_thread_desc_m_o
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
YDotYGrad_M_O
::
ThreadSliceLength_M
,
YDotYGrad_M_O
::
ThreadSliceLength_O
));
constexpr
auto
y_thread_cluster_desc
=
make_cluster_descriptor
(
Sequence
<
I1
,
YDotYGrad_M_O
::
ThreadClusterLength_M
,
I1
,
YDotYGrad_M_O
::
ThreadClusterLength_O
>
{},
Sequence
<
0
,
1
,
2
,
3
>
{});
const
auto
y_thread_cluster_idx
=
y_thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
constexpr
auto
ygrad_thread_cluster_desc
=
make_cluster_descriptor
(
Sequence
<
YDotYGrad_M_O
::
ThreadClusterLength_M
,
YDotYGrad_M_O
::
ThreadClusterLength_O
>
{},
Sequence
<
0
,
1
>
{});
const
auto
ygrad_thread_cluster_idx
=
ygrad_thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
const
auto
y_thread_data_on_block_idx
=
y_thread_cluster_idx
*
y_thread_desc_m0_m1_o0_o1
.
GetLengths
();
const
auto
ygrad_thread_data_on_block_idx
=
ygrad_thread_cluster_idx
*
ygrad_thread_desc_m_o
.
GetLengths
();
const
auto
y_thread_data_on_grid_idx
=
make_multi_index
(
num_gemm0_m_block_outer_loop
-
1
,
I0
,
I0
,
I0
)
+
y_thread_data_on_block_idx
;
// performs for y
auto
y_threadwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
InputDataType
,
FloatGemmAcc
,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
decltype
(
y_thread_desc_m0_m1_o0_o1
),
decltype
(
y_thread_desc_m0_m1_o0_o1
.
GetLengths
()),
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
// SrcVectorDim
YDotYGrad_M_O
::
SrcScalarPerVector
,
// SrcScalarPerVector
1
,
// SrcScalarStrideInVector
true
/* ResetCoordAfterRun */
>
(
y_grid_desc_mblock_mperblock_oblock_operblock
,
y_thread_data_on_grid_idx
);
// performs for ygrad
auto
ygrad_threadwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
GemmDataType
,
FloatGemmAcc
,
decltype
(
YDotYGrad_M_O
::
ygrad_block_desc_m_o
),
decltype
(
ygrad_thread_desc_m_o
),
decltype
(
ygrad_thread_desc_m_o
.
GetLengths
()),
Sequence
<
0
,
1
>
,
1
,
// SrcVectorDim
YDotYGrad_M_O
::
SrcScalarPerVector
,
// SrcScalarPerVector
1
,
// SrcScalarStrideInVector
true
/* ResetCoordAfterRun */
>
(
YDotYGrad_M_O
::
ygrad_block_desc_m_o
,
ygrad_thread_data_on_block_idx
);
constexpr
auto
y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
P_M0
,
P_M1
,
P_M2
,
P_M3
,
P_M4
));
constexpr
auto
y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4
=
lse_thread_desc_mb_m0_m1_m2_m3_m4
;
// reuse LSE thread descriptor because
// per-thread LSE data and y_dot_ygrad is
// tiled the same way
auto
y_dot_ygrad_thread_copy_lds_to_vgpr
=
ThreadwiseTensorSliceTransfer_v2
<
FloatGemmAcc
,
FloatGemmAcc
,
decltype
(
y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4
),
decltype
(
y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4
),
Sequence
<
1
,
m0
,
m1
,
m2
,
m3
,
m4
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
1
,
1
,
true
/* ResetCoordAfterRun */
>
{
y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4
,
make_multi_index
(
I0
,
// mblock
acc0_thread_origin
[
I0
],
// mrepeat
acc0_thread_origin
[
I2
],
// mwave
acc0_thread_origin
[
I4
],
// mperxdl
acc0_thread_origin
[
I5
],
acc0_thread_origin
[
I6
])};
auto
y_dot_ygrad_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatGemmAcc
>
(
y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4
.
GetElementSpaceSize
());
...
...
@@ -1809,7 +1648,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
block_sync_lds
();
//
//
calculate Y dot dY
//
load d and lse
//
lse_thread_copy_global_to_vgpr
.
Run
(
lse_grid_desc_mb_m0_m1_m2_m3_m4
,
...
...
@@ -2071,6 +1910,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
sgrad_slice_idx
[
I3
],
sgrad_slice_idx
[
I3
]
+
Gemm2Params
::
ABlockSliceLengths_M0_K0_M1_K1
::
At
(
I3
));
block_sync_lds
();
// sync before write
if
(
gemm2_a_copy_subgroup
.
IsBelong
(
mwave_range
,
nwave_range
))
{
qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds
.
Run
(
...
...
@@ -2082,8 +1922,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
gemm2_a_block_buf
);
}
// block_sync_lds(); // sync before write
qgrad_gemm_tile_k_blockwise_copy
.
Run
(
Gemm2
::
b_block_desc_n0_n1_n2_k0_k1_k2_k3
,
k_block_buf
,
Gemm2
::
b_thread_desc_n0_n1_n2_k0_k1_k2_k3
,
...
...
@@ -2127,9 +1965,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
make_multi_index
(
-
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
lse_thread_copy_global_to_vgpr
.
MoveSrcSliceWindow
(
lse_grid_desc_mb_m0_m1_m2_m3_m4
,
make_multi_index
(
-
1
,
0
,
0
,
0
,
0
,
0
));
y_threadwise_copy
.
MoveSrcSliceWindow
(
y_grid_desc_mblock_mperblock_oblock_operblock
,
make_multi_index
(
-
1
,
0
,
0
,
0
));
}
while
(
0
<
gemm0_m_block_outer_index
--
);
// end j loop
// shuffle dK&dV and write
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
View file @
6b814bac
...
...
@@ -338,25 +338,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num_loop
);
}
__host__
__device__
static
constexpr
auto
MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
(
const
YGridDesc_M_O
&
y_grid_desc_m_o
)
{
const
auto
M
=
y_grid_desc_m_o
.
GetLength
(
I0
);
const
auto
O
=
y_grid_desc_m_o
.
GetLength
(
I1
);
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
OBlock
=
O
/
Gemm1NPerBlock
;
const
auto
y_grid_desc_mblock_mperblock_oblock_operblock
=
transform_tensor_descriptor
(
y_grid_desc_m_o
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MPerBlock
>
{})),
make_unmerge_transform
(
make_tuple
(
OBlock
,
Number
<
Gemm1NPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
return
y_grid_desc_mblock_mperblock_oblock_operblock
;
}
template
<
typename
SrcBlockwiseGemm
>
__host__
__device__
static
constexpr
auto
MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4
(
const
LSEGridDesc_M
&
lse_grid_desc_m
)
...
...
@@ -455,9 +436,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
k_grid_desc_n_k
);
}
using
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
=
remove_cvref_t
<
decltype
(
MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
(
YGridDesc_M_O
{}))
>
;
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
KGridDesc_N_K
{}))
>
;
...
...
@@ -966,30 +944,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
true
>
;
};
template
<
index_t
BlockSize_
,
index_t
BlockSliceLength_M_
,
index_t
BlockSliceLength_O_
>
struct
YDotYGrad_M_O_
{
static
constexpr
index_t
SrcScalarPerVector
=
16
/
sizeof
(
InputDataType
);
static
constexpr
auto
ThreadClusterLength_O
=
Number
<
BlockSliceLength_O_
/
SrcScalarPerVector
>
{};
static
constexpr
auto
ThreadClusterLength_M
=
Number
<
BlockSize_
/
ThreadClusterLength_O
>
{};
static
constexpr
auto
ThreadSliceLength_O
=
Number
<
SrcScalarPerVector
>
{};
static
constexpr
auto
ThreadSliceLength_M
=
Number
<
BlockSliceLength_M_
*
ThreadClusterLength_O
/
BlockSize_
>
{};
static_assert
(
ThreadClusterLength_O
*
ThreadSliceLength_O
==
BlockSliceLength_O_
,
""
);
static_assert
(
ThreadClusterLength_M
*
ThreadSliceLength_M
==
BlockSliceLength_M_
,
""
);
using
SrcBufType
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
FloatGemmAcc
,
ThreadSliceLength_M
*
ThreadSliceLength_O
,
true
>
;
using
DstBufType
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
FloatGemmAcc
,
ThreadSliceLength_M
,
true
>
;
};
using
YDotYGrad_M_O
=
YDotYGrad_M_O_
<
BlockSize
,
MPerBlock
,
Gemm1NPerBlock
>
;
// PGrad Gemm has the same layout as P = Q * K^T Gemm (A row-major B col-major)
struct
PGradGemmTile_M_N_O
{
...
...
@@ -1180,8 +1134,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
const
VGridDesc_N0_O_N1
&
v_grid_desc_n0_o_n1
,
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
&
y_grid_desc_mblock_mperblock_oblock_operblock
,
const
LSEGridDesc_M
&
lse_grid_desc_m
,
const
DGridDesc_M
&
d_grid_desc_m
,
const
YGradGridDesc_M0_O_M1
&
ygrad_grid_desc_m0_o_m1
,
...
...
@@ -1627,78 +1579,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
wave_m_n_id
[
I1
]),
// NPerXdl
tensor_operation
::
element_wise
::
PassThrough
{}};
//
// set up Y dot dY
//
// m0, n0 are m/n repeat per wave
// m1, n1 are number of waves
constexpr
auto
p_block_lengths
=
s_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
().
GetLengths
();
constexpr
auto
P_M0
=
p_block_lengths
[
I0
];
// repeats
constexpr
auto
P_M1
=
p_block_lengths
[
I2
];
// waves
constexpr
auto
P_M2
=
p_block_lengths
[
I4
];
// xdl
constexpr
auto
P_M3
=
p_block_lengths
[
I5
];
constexpr
auto
P_M4
=
p_block_lengths
[
I6
];
constexpr
auto
y_thread_desc_m0_m1_o0_o1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
YDotYGrad_M_O
::
ThreadSliceLength_M
,
I1
,
YDotYGrad_M_O
::
ThreadSliceLength_O
));
constexpr
auto
y_thread_cluster_desc
=
make_cluster_descriptor
(
Sequence
<
I1
,
YDotYGrad_M_O
::
ThreadClusterLength_M
,
I1
,
YDotYGrad_M_O
::
ThreadClusterLength_O
>
{},
Sequence
<
0
,
1
,
2
,
3
>
{});
const
auto
y_thread_cluster_idx
=
y_thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
const
auto
y_thread_data_on_block_idx
=
y_thread_cluster_idx
*
y_thread_desc_m0_m1_o0_o1
.
GetLengths
();
const
auto
y_thread_data_on_grid_idx
=
make_multi_index
(
num_gemm0_m_block_outer_loop
-
1
,
I0
,
I0
,
I0
)
+
y_thread_data_on_block_idx
;
// performs double duty for both y and ygrad
auto
yygrad_threadwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
InputDataType
,
FloatGemmAcc
,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
decltype
(
y_thread_desc_m0_m1_o0_o1
),
decltype
(
y_thread_desc_m0_m1_o0_o1
.
GetLengths
()),
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
// SrcVectorDim
YDotYGrad_M_O
::
SrcScalarPerVector
,
// SrcScalarPerVector
1
,
// SrcScalarStrideInVector
true
/* ResetCoordAfterRun */
,
false
/* InvalidElementAsNaN */
>
(
y_grid_desc_mblock_mperblock_oblock_operblock
,
y_thread_data_on_grid_idx
);
constexpr
auto
y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
P_M0
,
P_M1
,
P_M2
,
P_M3
,
P_M4
));
constexpr
auto
y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4
=
lse_thread_desc_mb_m0_m1_m2_m3_m4
;
// reuse LSE thread descriptor because
// per-thread LSE data and y_dot_ygrad is
// tiled the same way
auto
y_dot_ygrad_thread_copy_lds_to_vgpr
=
ThreadwiseTensorSliceTransfer_v2
<
FloatGemmAcc
,
FloatGemmAcc
,
decltype
(
y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4
),
decltype
(
y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4
),
Sequence
<
1
,
m0
,
m1
,
m2
,
m3
,
m4
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
1
,
1
,
true
/* ResetCoordAfterRun */
>
{
y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4
,
make_multi_index
(
I0
,
// mblock
acc0_thread_origin
[
I0
],
// mrepeat
acc0_thread_origin
[
I2
],
// mwave
acc0_thread_origin
[
I4
],
// mperxdl
acc0_thread_origin
[
I5
],
acc0_thread_origin
[
I6
])};
auto
y_dot_ygrad_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatGemmAcc
>
(
y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4
.
GetElementSpaceSize
());
...
...
@@ -1724,7 +1609,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
}
//
//
calculate Y dot dY
//
load d and lse
//
lse_thread_copy_global_to_vgpr
.
Run
(
lse_grid_desc_mb_m0_m1_m2_m3_m4
,
...
...
@@ -2002,6 +1887,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
sgrad_slice_idx
[
I3
],
sgrad_slice_idx
[
I3
]
+
Gemm2Params
::
ABlockSliceLengths_M0_K0_M1_K1
::
At
(
I3
));
block_sync_lds
();
// sync before write
if
(
gemm2_a_copy_subgroup
.
IsBelong
(
mwave_range
,
nwave_range
))
{
qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds
.
Run
(
...
...
@@ -2018,7 +1904,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
qgrad_gemm_tile_k_blockwise_copy
.
MoveSrcSliceWindow
(
k_grid_desc_n0_k_n1
,
Gemm2
::
b_block_slice_copy_step
);
block_sync_lds
();
// sync before write
qgrad_gemm_tile_k_blockwise_copy
.
RunWrite
(
Gemm2
::
b_block_desc_k0_n_k1
,
gemm2_b_block_buf
);
...
...
@@ -2120,9 +2005,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4
,
Gemm2
::
c_block_slice_copy_step
);
// step M
lse_thread_copy_global_to_vgpr
.
MoveSrcSliceWindow
(
lse_grid_desc_mb_m0_m1_m2_m3_m4
,
make_multi_index
(
-
1
,
0
,
0
,
0
,
0
,
0
));
yygrad_threadwise_copy
.
MoveSrcSliceWindow
(
y_grid_desc_mblock_mperblock_oblock_operblock
,
make_multi_index
(
-
1
,
0
,
0
,
0
));
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_multi_index
(
-
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment