gaoqiong / composable_kernel · Commits

Commit 21ef37b4 (Unverified)
Authored Sep 11, 2023 by Dan Yao; committed by GitHub on Sep 11, 2023

Merge pull request #889 from ROCmSoftwarePlatform/mha-train-develop-bwdopt-bias

Mha train develop bwdopt bias

Parents: 1f04cd2b, db579ac9
Showing 2 changed files with 729 additions and 393 deletions (+729 −393)

include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp   +722 −390
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_ydotygrad.hpp  +7 −3
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp (view file @ 21ef37b4)
...
...
@@ -39,7 +39,7 @@ template <typename InputDataType,
           typename KGridDesc_N_K,
           typename D0GridDesc_M_N,
           typename ZGridDesc_M_N,
-          typename VGridDesc_N0_O_N1,
+          typename VGridDesc_O0_N_O1,
           typename YGridDesc_M_O,
           typename LSEGridDesc_M,
           index_t NumGemmKPrefetchStage,
...
...
@@ -49,6 +49,7 @@ template <typename InputDataType,
           index_t KPerBlock,
           index_t Gemm1NPerBlock,
           index_t Gemm1KPerBlock,
+          index_t Gemm2KPerBlock,
           index_t AK1Value,
           index_t BK1Value,
           index_t B1K1Value,
...
...
@@ -124,6 +125,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     static constexpr auto B1K1 = Number<B1K1Value>{};

     static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;

+    static constexpr auto K_K0 = Number<Gemm1NPerBlock / BK1Value>{};
+    static constexpr auto V_K3 = BK1;
+    static constexpr auto V_K2 = mfma.num_input_blks;
+    static constexpr auto V_K1 = KPerBlock / V_K2 / V_K3;
+    static constexpr auto V_K0 = Gemm1NPerBlock / KPerBlock;
+    static constexpr auto V_N1 = NXdlPerWave;

     static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
     // get_random_8x16() generates 8 random numbers each time
     static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
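For context on the `// 2` and `// 16` annotations above: with the fp16 MFMA variants this kernel targets, `mfma.num_input_blks` is typically 2, and the dropout RNG yields 8 values per call, so the tile works out to 16. A minimal standalone sketch, assuming those values:

    #include <cstdio>

    // Hypothetical stand-ins for the kernel's compile-time quantities.
    constexpr int num_input_blks = 2;                  // mfma.num_input_blks for the assumed MFMA
    constexpr int DropoutNThread = num_input_blks;     // matches the diff's "// 2"
    constexpr int DropoutTile    = DropoutNThread * 8; // get_random_8x16() yields 8 values per call

    static_assert(DropoutTile == 16, "matches the diff's \"// 16\" annotation");

    int main() { std::printf("DropoutTile = %d\n", DropoutTile); }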
...
...
@@ -197,6 +204,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
             make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
     }

+    __host__ __device__ static constexpr auto GetKBlockDescriptor_K0PerBlock_NPerBlock_K1()
+    {
+        // K matrix in LDS memory, dst of blockwise copy
+        return make_naive_tensor_descriptor(
+            make_tuple(K_K0, Number<NPerBlock>{}, BK1),
+            make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
+    }
+
+    __host__ __device__ static constexpr auto GetVThreadDescriptor_K0_K1_K2_N0_N1_N2_N3_K3()
+    {
+        // V matrix in Vgpr, dst of threadwise copy
+        return make_naive_tensor_descriptor_packed(
+            make_tuple(I1, Number<V_K1>{}, I1, I1, Number<V_N1>{}, I1, I1, Number<V_K3>{}));
+    }
+
     template <typename AccThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2>
     __host__ __device__ static constexpr auto GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
         const AccThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2& acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2)
...
...
@@ -277,36 +299,36 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
QGridDesc_K0_M_K1
&
q_grid_desc_k0_m_k1
,
const
KGridDesc_K0_N_K1
&
k_grid_desc_k0_n_k1
,
const
VGridDesc_
N0_O_N
1
&
v_grid_desc_
n0_o_n
1
,
const
VGridDesc_
O0_N_O
1
&
v_grid_desc_
o0_n_o
1
,
const
YGridDesc_M_O
&
y_grid_desc_m_o
)
{
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
q_grid_desc_k0_m_k1
.
GetLength
(
I2
);
const
auto
Gemm1N
=
v_grid_desc_
n0
_o_n1
.
GetLength
(
I
1
);
const
auto
M
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
q_grid_desc_k0_m_k1
.
GetLength
(
I2
);
const
auto
O
=
v_grid_desc_
o0_n_o1
.
GetLength
(
I0
)
*
v_grid_desc
_o
0
_n
_o
1
.
GetLength
(
I
2
);
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if
(
Gemm1N
!=
K
)
if
(
O
!=
K
)
{
std
::
cerr
<<
"SizeK must be equal to SizeO (equal attention head size)"
<<
'\n'
;
return
false
;
}
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
Gemm1N
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
if
(
!
(
M
==
y_grid_desc_m_o
.
GetLength
(
I0
)
&&
O
==
y_grid_desc_m_o
.
GetLength
(
I1
)))
{
return
false
;
}
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
Gemm1N
%
Gemm1NPerBlock
==
0
))
O
%
Gemm1NPerBlock
==
0
))
{
return
false
;
}
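The divisibility requirements above are easy to mirror on the host before launching. A minimal sketch (hypothetical tile sizes, not the library's API):

    #include <iostream>

    // Hypothetical tile sizes standing in for the kernel's template parameters.
    constexpr int MPerBlock = 128, NPerBlock = 128, KPerBlock = 32, Gemm1NPerBlock = 128;

    // Mirrors CheckValidity's runtime checks: head sizes must match and every
    // problem dimension must tile evenly.
    bool shapes_are_valid(int M, int N, int K, int O)
    {
        if(O != K)
        {
            std::cerr << "SizeK must equal SizeO (attention head size)\n";
            return false;
        }
        return M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
               O % Gemm1NPerBlock == 0;
    }

    int main() { std::cout << shapes_are_valid(512, 512, 128, 128) << '\n'; }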
...
...
@@ -411,22 +433,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     }

     __device__ static auto MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(
-        const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1)
+        const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1)
     {
-        const auto N0 = v_grid_desc_n0_o_n1.GetLength(I0);
-        const auto O  = v_grid_desc_n0_o_n1.GetLength(I1);
-        const auto N1 = v_grid_desc_n0_o_n1.GetLength(I2);
-        const auto N  = N0 * N1;
+        const auto O0 = v_grid_desc_o0_n_o1.GetLength(I0);
+        const auto N  = v_grid_desc_o0_n_o1.GetLength(I1);
+        const auto O1 = v_grid_desc_o0_n_o1.GetLength(I2);
+        const auto O  = O0 * O1;

         const auto NBlock = N / NPerBlock;
         const auto OBlock = O / Gemm1NPerBlock;

         const auto v_grid_desc_n_o = transform_tensor_descriptor(
-            v_grid_desc_n0_o_n1,
-            make_tuple(make_pass_through_transform(O),
-                       make_merge_transform_v3_division_mod(make_tuple(N0, N1))),
+            v_grid_desc_o0_n_o1,
+            make_tuple(make_pass_through_transform(N),
+                       make_merge_transform_v3_division_mod(make_tuple(O0, O1))),
             make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
-            make_tuple(Sequence<1>{}, Sequence<0>{}));
+            make_tuple(Sequence<0>{}, Sequence<1>{}));

         return transform_tensor_descriptor(
             v_grid_desc_n_o,
...
...
@@ -438,14 +460,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     __device__ static auto MakeQGradGridDesc_M_K(const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1)
     {
-        const auto K_K0 = q_grid_desc_k0_m_k1.GetLength(I0);
-        const auto M    = q_grid_desc_k0_m_k1.GetLength(I1);
-        const auto K_K1 = q_grid_desc_k0_m_k1.GetLength(I2);
+        const auto K0_ = q_grid_desc_k0_m_k1.GetLength(I0);
+        const auto M_  = q_grid_desc_k0_m_k1.GetLength(I1);
+        const auto K1_ = q_grid_desc_k0_m_k1.GetLength(I2);

         return transform_tensor_descriptor(
             q_grid_desc_k0_m_k1,
-            make_tuple(make_pass_through_transform(M),
-                       make_merge_transform_v3_division_mod(make_tuple(K_K0, K_K1))),
+            make_tuple(make_pass_through_transform(M_),
+                       make_merge_transform_v3_division_mod(make_tuple(K0_, K1_))),
             make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}));
     }
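The merge transform used here is plain index arithmetic: a (K0, M, K1) view collapses to (M, K) with K = K0 * K1, and a linear k splits back as k0 = k / K1, k1 = k % K1. A standalone sketch of that mapping (hypothetical K0/K1; the real ones come from the Q grid descriptor):

    #include <cassert>

    // Sketch of what make_merge_transform_v3_division_mod computes for (K0, K1) -> K.
    constexpr int K0 = 4, K1 = 8, K = K0 * K1;

    constexpr int merged(int k0, int k1) { return k0 * K1 + k1; }

    int main()
    {
        for(int k = 0; k < K; ++k)
        {
            const int k0 = k / K1; // division part of division_mod
            const int k1 = k % K1; // modulo part
            assert(merged(k0, k1) == k);
        }
    }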
...
...
@@ -467,16 +489,120 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 = remove_cvref_t<decltype(
         MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(ZGridDesc_M_N{}))>;

-    // S / dP Gemm (type 1 rcc)
+    // K / V
+    struct GemmBlockwiseCopy
+    {
+        __device__ static auto
+        MakeVGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1)
+        {
+            const auto K0_ = v_grid_desc_o0_n_o1.GetLength(I0);
+            const auto N_  = v_grid_desc_o0_n_o1.GetLength(I1);
+            const auto K1_ = v_grid_desc_o0_n_o1.GetLength(I2);
+
+            constexpr auto V_N3 = NPerXdl;
+            constexpr auto V_N2 = Gemm0NWaves;
+            const auto V_N0     = N_ / NPerBlock;
+
+            const auto v_grid_desc_n_k = transform_tensor_descriptor(
+                v_grid_desc_o0_n_o1,
+                make_tuple(make_pass_through_transform(N_),
+                           make_merge_transform_v3_division_mod(make_tuple(K0_, K1_))),
+                make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+            return transform_tensor_descriptor(
+                v_grid_desc_n_k,
+                make_tuple(make_unmerge_transform(make_tuple(V_N0, V_N1, V_N2, V_N3)),
+                           make_unmerge_transform(make_tuple(V_K0, V_K1, V_K2, V_K3))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<3, 4, 5, 6>{}, Sequence<0, 1, 2, 7>{}));
+        }
+
+        // K matrix in LDS, dst of blockwise copy
+        static constexpr auto k_block_desc_k0_n_k1 = GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
+        // V matrix in Vgpr, dst of threadwise copy
+        static constexpr auto v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
+            GetVThreadDescriptor_K0_K1_K2_N0_N1_N2_N3_K3();
+
+        template <typename GridDesc_K0_N_K1>
+        using KBlockwiseCopy =
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
+                                                tensor_operation::element_wise::PassThrough,
+                                                tensor_operation::element_wise::PassThrough,
+                                                InMemoryDataOperationEnum::Set,
+                                                Sequence<K_K0, NPerBlock, BK1>,
+                                                BBlockTransferThreadClusterLengths_BK0_N_BK1,
+                                                BBlockTransferThreadClusterArrangeOrder,
+                                                InputDataType,
+                                                GemmDataType,
+                                                GridDesc_K0_N_K1,
+                                                decltype(k_block_desc_k0_n_k1),
+                                                BBlockTransferSrcAccessOrder,
+                                                Sequence<1, 0, 2>,
+                                                BBlockTransferSrcVectorDim,
+                                                2,
+                                                BBlockTransferSrcScalarPerVector,
+                                                BBlockTransferDstScalarPerVector_BK1,
+                                                1,
+                                                1,
+                                                true, // SrcResetCoord
+                                                true, // DstResetCoord
+                                                NumGemmKPrefetchStage>;
+
+        template <typename GridDesc_K0_K1_k2_N0_N1_N2_N3_K3>
+        using VBlockwiseCopy = ThreadwiseTensorSliceTransfer_v2<
+            InputDataType,
+            GemmDataType,
+            GridDesc_K0_K1_k2_N0_N1_N2_N3_K3,
+            decltype(v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3),
+            decltype(v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLengths()),
+            Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
+            7,
+            BK1,
+            1,
+            true /* ResetCoordAfterRun */>;
+
+        static constexpr auto VBlockBufferSize = V_K0;
+
+        static constexpr auto v_block_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
+    };
+
+    // dP Gemm (type 1 rcc, B in Vgpr)
+    template <typename BSrcThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3>
     struct Gemm0
     {
         // A matrix in LDS memory, dst of blockwise copy
         static constexpr auto a_block_desc_ak0_m_ak1 =
             GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

-        // B matrix in LDS memory, dst of blockwise copy
-        static constexpr auto b_block_desc_bk0_n_bk1 =
-            GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
+        template <typename BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3>
+        __host__ __device__ static constexpr auto GetBThreadDescriptor_K0_N_K1(
+            const BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3& b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3)
+        {
+            // b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 to b_thread_desc_k0_n_k1
+            // k0_k1_k2    -> k0
+            // n0_n1_n2_n3 -> n
+            // k3          -> k1
+            const auto k0 = b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I0);
+            const auto k1 = b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I1);
+            const auto k2 = b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I2);
+            const auto n0 = b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I3);
+            const auto n1 = b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I4);
+            const auto n2 = b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I5);
+            const auto n3 = b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I6);
+            const auto k3 = b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I7);
+
+            return transform_tensor_descriptor(
+                b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+                make_tuple(make_merge_transform_v3_division_mod(make_tuple(k0, k1, k2)),
+                           make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
+                           make_pass_through_transform(k3)),
+                make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}, Sequence<7>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+        }
+
+        static constexpr auto b_src_thread_desc_k0_n_k1 =
+            GetBThreadDescriptor_K0_N_K1(BSrcThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3{});

         template <typename ABlockDesc_AK0_M_AK1>
         __host__ __device__ static constexpr auto
...
...
@@ -492,9 +618,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     __host__ __device__ static constexpr auto
     MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
     {
-        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
-
-        return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
+        return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, 1, 1>(
             BBlockDesc_BK0_N_BK1{});
     }
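One way to read this change (an interpretation, not stated in the diff): once B is held per-thread in VGPRs, the MMA tile descriptor no longer needs to unmerge N across waves and lanes, so those factors become 1. A toy sketch of the two tilings, with hypothetical sizes:

    // Hypothetical block/wave sizes for illustration only.
    constexpr int NPerBlock = 128, NXdlPerWave = 4, NPerXdl = 32;

    // LDS-resident B: tile N as (repeat, waves, per-xdl) covering the whole block.
    constexpr int NWaves = NPerBlock / (NXdlPerWave * NPerXdl); // == 1 here
    static_assert(NXdlPerWave * NWaves * NPerXdl == NPerBlock);

    // VGPR-resident B: each thread already owns its slice, so the wave and
    // per-xdl factors collapse to 1 and only the repeat dimension remains.
    constexpr int TileN_vgpr = NXdlPerWave * 1 * 1;
    static_assert(TileN_vgpr == NXdlPerWave);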
...
...
@@ -523,31 +647,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
true
,
// DstResetCoord
NumGemmKPrefetchStage
>
;
template
<
typename
GridDesc_K0_N_K1
>
using
BBlockwiseCopy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
BElementwiseOperation
,
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
BK0
,
NPerBlock
,
BK1
>
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
InputDataType
,
GemmDataType
,
GridDesc_K0_N_K1
,
decltype
(
b_block_desc_bk0_n_bk1
),
BBlockTransferSrcAccessOrder
,
Sequence
<
1
,
0
,
2
>
,
BBlockTransferSrcVectorDim
,
2
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
1
,
1
,
true
,
// SrcResetCoord
true
,
// DstResetCoord
NumGemmKPrefetchStage
>
;
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
mfma
.
k_per_blk
);
// Blockwise gemm with transposed XDL output
...
...
@@ -556,9 +655,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
GemmDataType
,
FloatGemmAcc
,
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
b_
block
_desc_
b
k0_n_
b
k1
),
decltype
(
b_
src_thread
_desc_k0_n_k1
),
decltype
(
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K
(
a_block_desc_ak0_m_ak1
)),
decltype
(
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K
(
b_
block
_desc_
b
k0_n_
b
k1
)),
decltype
(
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K
(
b_
src_thread
_desc_k0_n_k1
)),
MPerBlock
,
NPerBlock
,
KPerBlock
,
...
...
@@ -566,10 +665,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
NPerXdl
,
MXdlPerWave
,
NXdlPerWave
,
KPack
,
false
,
KPack
*
XdlopsGemm
<
GemmDataType
,
MPerXdl
,
NPerXdl
,
KPack
,
false
>
{}.
K0PerXdlops
,
KPack
>
;
static
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
AK1
,
0
,
0
);
static
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
BK1
,
0
,
0
);
};
// dV / dK Gemm (type 2 rrr)
...
...
@@ -707,48 +808,41 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     // dQ Gemm (type 3 crr)
     // Describes tuning parameter for C2_m_n = A2_m_k * B2_k_n
-    template <index_t Sum_K_ = NPerXdl * 2>
-    struct Gemm2Params_
+    struct Gemm2Params
     {
-        static constexpr index_t Gemm2_M = MPerBlock;
-        static constexpr index_t Gemm2_K = NPerBlock;
-        static constexpr index_t Gemm2_N = Gemm1NPerBlock;
-        static constexpr index_t Sum_K   = Sum_K_;
+        static constexpr index_t Gemm2_M = MPerBlock;      // 64
+        static constexpr index_t Gemm2_K = NPerBlock;      // 128
+        static constexpr index_t Gemm2_N = Gemm1NPerBlock; // 128
+        static constexpr index_t Sum_K   = Gemm2KPerBlock;

-        static constexpr index_t A_K1     = 8; // P will be row-major
+        static constexpr index_t A_K1     = 8; // dS will be row-major
         static constexpr index_t A_K0     = Sum_K / A_K1;
         static constexpr index_t A_LdsPad = 0; // how many multiples of K1 per M * K1 elements

-        static constexpr index_t B_K1     = B1K1; // dY assumed row-major, typically =2 for fp16
-        static constexpr index_t B_K0     = Sum_K / B_K1;
-        static constexpr index_t B_LdsPad = 0; // how many multiples of K1 per N * K1 elements
-
-        static_assert(Sum_K % NPerXdl == 0, "");
-
-        static constexpr index_t BSrcVectorDim       = 1; // Gemm2_N dimension
-        static constexpr index_t BSrcScalarPerVector = 4;
-
-        static constexpr index_t GemmNWave   = Gemm2_N / Gemm2NXdlPerWave / NPerXdl;
-        static constexpr index_t GemmMWave   = BlockSize / get_warp_size() / GemmNWave;
-        static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave;
-        static constexpr index_t GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl;
-        static constexpr index_t GemmKPack   = math::max(math::lcm(A_K1, B_K1), mfma.k_per_blk);
-
-        using BBlockSliceLengths = Sequence<B_K0, Gemm2_N, B_K1>;
-        using BThreadClusterLengths =
-            Sequence<BlockSize / (Gemm2_N / BSrcScalarPerVector), Gemm2_N / BSrcScalarPerVector, 1>;
-        using BThreadClusterArrangeOrder = Sequence<0, 2, 1>;
+        static constexpr index_t GemmNWave   = Gemm2_N / Gemm2NXdlPerWave / NPerXdl;    // 1 // 2
+        static constexpr index_t GemmMWave   = BlockSize / get_warp_size() / GemmNWave; // 4 // 2
+        static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave;                        // 1 // 1
+        static constexpr index_t GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl;           // 1 // 1
+        static constexpr index_t GemmKLoop   = Gemm2_K / Sum_K;                         // 2 // 2
+        static constexpr index_t GemmKPack   = math::max(A_K1, mfma.k_per_blk);
+
+        static constexpr index_t B_K3 = GemmKPack; // 8
+        static constexpr index_t B_K2 =
+            XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, GemmKPack, false>{}.K0PerXdlops; // 2
+        static constexpr index_t B_K1 = Sum_K / B_K2 / B_K3; // 4
+        static constexpr index_t B_K0 = GemmKLoop;           // 2

         __host__ __device__ static constexpr auto GetABlockSliceLengths_M0_K0_M1_K1_M2_K2()
         {
             // perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl
-            constexpr index_t k  = Gemm2Params::Sum_K - 1;
+            constexpr index_t k  = Sum_K - 1;
             constexpr index_t k2 = k % NPerXdl;
             constexpr index_t k1 = k / NPerXdl % Gemm0NWaves;
             constexpr index_t k0 = k / NPerXdl / Gemm0NWaves % NXdlPerWave;

             // perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl
-            constexpr index_t m  = Gemm2Params::Gemm2_M - 1;
+            constexpr index_t m  = Gemm2_M - 1;
             constexpr index_t m2 = m % MPerXdl;
             constexpr index_t m1 = m / MPerXdl % Gemm0MWaves;
             constexpr index_t m0 = m / MPerXdl / Gemm0MWaves % MXdlPerWave;
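The new B_K decomposition has a simple invariant: the three inner K factors must tile Sum_K exactly, and B_K0 counts the outer K loops. A compile-time check using the values from the diff's own `//` annotations (Sum_K = 64, B_K3 = 8, B_K2 = 2):

    // Hypothetical values matching the "//" annotations in Gemm2Params above.
    constexpr int Gemm2_K   = 128, Sum_K = 64; // Sum_K now comes from Gemm2KPerBlock
    constexpr int GemmKPack = 8;               // max(A_K1, mfma.k_per_blk) for this config
    constexpr int B_K3      = GemmKPack;            // 8
    constexpr int B_K2      = 2;                    // K0PerXdlops for the assumed MFMA
    constexpr int B_K1      = Sum_K / B_K2 / B_K3;  // 4
    constexpr int B_K0      = Gemm2_K / Sum_K;      // 2 == GemmKLoop

    static_assert(B_K1 * B_K2 * B_K3 == Sum_K, "inner K factors tile Sum_K");
    static_assert(B_K0 * Sum_K == Gemm2_K, "B_K0 counts outer K loops");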
...
...
@@ -769,10 +863,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using
ABlockSliceLengths_M0_K0_M1_K1
=
decltype
(
GetABlockSliceLengths_M0_K0_M1_K1
());
//(2, 1, 1, 2) //(4, 1, 1, 2)
};
using
Gemm2Params
=
Gemm2Params_
<>
;
// tune later
// dQ Gemm (type 3 crr)
template
<
typename
Gemm2Params
,
typename
ASrcBlockwiseGemm
>
template
<
typename
Gemm2Params
,
typename
ASrcBlockwiseGemm
,
typename
BSrcBlockDesc_N0_K_N1
>
struct
Gemm2
{
private:
...
...
@@ -795,8 +888,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// A matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
a_block_desc_k0_m_k1
=
GetA2BlockDescriptor_K0_M_K1
<
Gemm2Params
>
();
// B matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
b_block_desc_k0_n_k1
=
GetB2BlockDescriptor_K0_N_K1
<
Gemm2Params
>
();
template
<
typename
ABlockDesc_K0_M_K1
>
__host__
__device__
static
constexpr
auto
MakeGemm2AMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_K0_M_K1
&
)
{
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
Gemm2Params
::
GemmMRepeat
,
Gemm2Params
::
GemmMWave
,
MPerXdl
>
(
ABlockDesc_K0_M_K1
{});
}
template
<
typename
BBlockDesc_K0_N_K1
>
__host__
__device__
static
constexpr
auto
MakeGemm2BMmaTileDescriptor_N0_N1_N2_K
(
const
BBlockDesc_K0_N_K1
&
)
{
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
Gemm2Params
::
GemmNRepeat
,
1
,
1
>
(
BBlockDesc_K0_N_K1
{});
}
__host__
__device__
static
constexpr
auto
MakeABlockDesc_M0_K0_M1_K1_M2_M3_M4_K2
()
{
...
...
@@ -875,49 +982,112 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
1
,
// DstScalarStrideInVector
true
>
;
template
<
typename
GridDesc_K0_N_K1
>
using
BBlockwiseCopy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
typename
Gemm2Params
::
BBlockSliceLengths
,
typename
Gemm2Params
::
BThreadClusterLengths
,
typename
Gemm2Params
::
BThreadClusterArrangeOrder
,
InputDataType
,
__host__
__device__
static
constexpr
auto
MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3
()
{
const
auto
N0_
=
BSrcBlockDesc_N0_K_N1
{}.
GetLength
(
I0
);
const
auto
K_
=
BSrcBlockDesc_N0_K_N1
{}.
GetLength
(
I1
);
const
auto
N1_
=
BSrcBlockDesc_N0_K_N1
{}.
GetLength
(
I2
);
constexpr
auto
b_block_desc_n_k
=
transform_tensor_descriptor
(
//(32, 128) //(64, 128)
BSrcBlockDesc_N0_K_N1
{},
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
N0_
,
N1_
)),
//(4, 8) //(8, 8)
make_pass_through_transform
(
K_
)),
// 128 // 128
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
transform_tensor_descriptor
(
b_block_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Gemm2Params
::
GemmNRepeat
,
Gemm2Params
::
GemmNWave
,
NPerXdl
)),
//(1, 1, 32) //(1, 2, 32)
make_unmerge_transform
(
make_tuple
(
Gemm2Params
::
B_K0
,
Gemm2Params
::
B_K1
,
Gemm2Params
::
B_K2
,
Gemm2Params
::
B_K3
))),
//(2, 4, 2, 8) //(2, 4, 2, 8)
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
,
6
>
{}));
}
static
constexpr
auto
b_block_desc_n0_n1_n2_k0_k1_k2_k3
=
MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3
();
using
BThreadSlice_N0_N1_N2_K0_K1_K2_K3
=
Sequence
<
Gemm2Params
::
GemmNRepeat
,
1
,
1
,
1
,
Gemm2Params
::
B_K1
,
1
,
Gemm2Params
::
B_K3
>
;
static
constexpr
auto
b_thread_desc_n0_n1_n2_k0_k1_k2_k3
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
Gemm2Params
::
GemmNRepeat
>
{},
I1
,
I1
,
I1
,
Number
<
Gemm2Params
::
B_K1
>
{},
I1
,
Number
<
Gemm2Params
::
B_K3
>
{}));
__host__
__device__
static
constexpr
auto
MakeBThreadDesc_K0_N_K1
()
{
constexpr
auto
b_thread_desc_n_k
=
transform_tensor_descriptor
(
b_thread_desc_n0_n1_n2_k0_k1_k2_k3
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
Gemm2Params
::
GemmNRepeat
>
{},
I1
,
I1
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
I1
,
Number
<
Gemm2Params
::
B_K1
>
{},
I1
,
Number
<
Gemm2Params
::
B_K3
>
{}))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
,
6
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
transform_tensor_descriptor
(
b_thread_desc_n_k
,
make_tuple
(
make_pass_through_transform
(
Number
<
Gemm2Params
::
GemmNRepeat
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
Gemm2Params
::
B_K1
>
{},
Number
<
Gemm2Params
::
B_K3
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}));
}
static
constexpr
auto
b_thread_desc_k0_n_k1
=
MakeBThreadDesc_K0_N_K1
();
using
BBlockwiseCopy
=
ThreadwiseTensorSliceTransfer_v2
<
GemmDataType
,
GemmDataType
,
decltype
(
b_block_desc_n0_n1_n2_k0_k1_k2_k3
),
decltype
(
b_thread_desc_n0_n1_n2_k0_k1_k2_k3
),
BThreadSlice_N0_N1_N2_K0_K1_K2_K3
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
>
,
6
,
1
,
1
,
true
>
;
static
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
0
,
0
,
0
,
1
,
0
,
0
,
0
);
static
constexpr
auto
b_block_reset_copy_step
=
make_multi_index
(
0
,
0
,
0
,
-
Gemm2Params
::
B_K0
,
0
,
0
,
0
);
using
BlockwiseGemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
GemmDataType
,
GridDesc_K0_N_K1
,
decltype
(
b_block_desc_k0_n_k1
),
typename
Gemm2Params
::
BThreadClusterArrangeOrder
,
// access order == thread order
Sequence
<
1
,
0
,
2
>
,
Gemm2Params
::
BSrcVectorDim
,
2
,
// DstVectorDim
Gemm2Params
::
BSrcScalarPerVector
,
Gemm2Params
::
B_K1
,
1
,
1
,
true
,
true
,
1
>
;
FloatGemmAcc
,
decltype
(
a_block_desc_k0_m_k1
),
decltype
(
b_thread_desc_k0_n_k1
),
decltype
(
MakeGemm2AMmaTileDescriptor_M0_M1_M2_K
(
a_block_desc_k0_m_k1
)),
decltype
(
MakeGemm2BMmaTileDescriptor_N0_N1_N2_K
(
b_thread_desc_k0_n_k1
)),
MPerBlock
,
Gemm1NPerBlock
,
Gemm2Params
::
Sum_K
,
MPerXdl
,
NPerXdl
,
Gemm2Params
::
GemmMRepeat
,
Gemm2Params
::
GemmNRepeat
,
Gemm2Params
::
GemmKPack
,
true
,
// TransposeC
Gemm2Params
::
GemmKPack
*
XdlopsGemm
<
GemmDataType
,
MPerXdl
,
NPerXdl
,
Gemm2Params
::
GemmKPack
,
false
>
{}
.
K0PerXdlops
,
Gemm2Params
::
GemmKPack
>
;
using
BlockwiseGemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
GemmDataType
,
FloatGemmAcc
,
decltype
(
a_block_desc_k0_m_k1
),
decltype
(
b_block_desc_k0_n_k1
),
MPerXdl
,
NPerXdl
,
Gemm2Params
::
GemmMRepeat
,
Gemm2Params
::
GemmNRepeat
,
Gemm2Params
::
GemmKPack
,
true
>
;
// TranspossC
static
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
Gemm2Params
::
B_K0
,
0
,
0
);
static
constexpr
auto
c_block_slice_copy_step
=
make_multi_index
(
-
Gemm2Params
::
GemmMRepeat
,
0
,
0
,
0
,
0
,
0
,
0
,
0
);
static
constexpr
auto
b_block_reset_copy_step
=
make_multi_index
(
-
NPerBlock
/
Gemm2Params
::
B_K1
,
0
,
0
);
template
<
typename
CGradDesc_M_N
>
__host__
__device__
static
auto
...
...
@@ -964,6 +1134,84 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                                              true>;
     };

+    // S Gemm (type 4 rcc, B in LDS)
+    template <typename BSrcBlockDesc_K0_N_K1>
+    struct Gemm3
+    {
+        // A matrix in LDS memory, dst of blockwise copy
+        static constexpr auto a_block_desc_ak0_m_ak1 =
+            GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
+
+        template <typename ABlockDesc_AK0_M_AK1>
+        __host__ __device__ static constexpr auto
+        MakeGemm3AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
+        {
+            constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
+
+            return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, MWaves, MPerXdl>(
+                ABlockDesc_AK0_M_AK1{});
+        }
+
+        template <typename BBlockDesc_BK0_N_BK1>
+        __host__ __device__ static constexpr auto
+        MakeGemm3BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
+        {
+            constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
+
+            return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
+                BBlockDesc_BK0_N_BK1{});
+        }
+
+        template <typename GridDesc_K0_M_K1>
+        using ABlockwiseCopy =
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
+                                                AElementwiseOperation,
+                                                tensor_operation::element_wise::PassThrough,
+                                                InMemoryDataOperationEnum::Set,
+                                                Sequence<AK0, MPerBlock, AK1>,
+                                                ABlockTransferThreadClusterLengths_AK0_M_AK1,
+                                                ABlockTransferThreadClusterArrangeOrder,
+                                                InputDataType,
+                                                GemmDataType,
+                                                GridDesc_K0_M_K1,
+                                                decltype(a_block_desc_ak0_m_ak1),
+                                                ABlockTransferSrcAccessOrder,
+                                                Sequence<1, 0, 2>,
+                                                ABlockTransferSrcVectorDim,
+                                                2,
+                                                ABlockTransferSrcScalarPerVector,
+                                                ABlockTransferDstScalarPerVector_AK1,
+                                                1,
+                                                1,
+                                                true, // SrcResetCoord
+                                                true, // DstResetCoord
+                                                NumGemmKPrefetchStage>;
+
+        static constexpr index_t KPack = math::max(math::lcm(AK1, BK1), mfma.k_per_blk);
+
+        // Blockwise gemm with transposed XDL output
+        using BlockwiseGemm = BlockwiseGemmXdlops_v2<
+            BlockSize,
+            GemmDataType,
+            FloatGemmAcc,
+            decltype(a_block_desc_ak0_m_ak1),
+            BSrcBlockDesc_K0_N_K1,
+            decltype(MakeGemm3AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
+            decltype(MakeGemm3BMmaTileDescriptor_N0_N1_N2_K(BSrcBlockDesc_K0_N_K1{})),
+            MPerBlock,
+            NPerBlock,
+            KPerBlock,
+            MPerXdl,
+            NPerXdl,
+            MXdlPerWave,
+            NXdlPerWave,
+            KPack>;
+
+        static constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
+        static constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KPerBlock);
+        static constexpr auto b_block_reset_copy_step = make_multi_index(0, 0, 0, -Gemm1NPerBlock);
+    };
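A pattern worth noting in these Gemm structs (a reading of the diff, not new behavior): the B slice window advances by KPerBlock per inner iteration, and the negative reset step rewinds it for the next outer iteration. A standalone sketch of that bookkeeping with hypothetical sizes:

    // Hypothetical extents mirroring the slice-window stepping in Gemm3 above.
    constexpr int Gemm1NPerBlock = 128, KPerBlock = 32;

    int main()
    {
        int window = 0;
        for(int outer = 0; outer < 3; ++outer)
        {
            for(int inner = 0; inner < Gemm1NPerBlock / KPerBlock; ++inner)
                window += KPerBlock;   // b_block_slice_copy_step advances the window
            window += -Gemm1NPerBlock; // b_block_reset_copy_step rewinds it
        }
        return window; // 0: every outer iteration starts from the same origin
    }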
    template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
    struct YDotYGrad_M_O_
    {
...
...
@@ -1014,26 +1262,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
             return ygrad_grid_desc_o0_m_o1;
         }

-        template <typename VGridDesc_N0_O_N1_>
-        __device__ static auto MakeVGridDesc_O0_N_O1(const VGridDesc_N0_O_N1_& v_grid_desc_n0_o_n1)
-        {
-            const auto N0 = v_grid_desc_n0_o_n1.GetLength(I0);
-            const auto O  = v_grid_desc_n0_o_n1.GetLength(I1);
-            const auto N1 = v_grid_desc_n0_o_n1.GetLength(I2);
-
-            constexpr auto V_O1 = BK1;
-            const auto V_O0     = O / V_O1;
-
-            const auto v_grid_desc_o0_n_o1 = transform_tensor_descriptor(
-                v_grid_desc_n0_o_n1,
-                make_tuple(make_unmerge_transform(make_tuple(V_O0, V_O1)),
-                           make_merge_transform_v3_division_mod(make_tuple(N0, N1))),
-                make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-            return v_grid_desc_o0_n_o1;
-        }
     };

     // QGrad Gemm has the same layout as Y = P * V Gemm (A in acc B row-major)
...
...
@@ -1042,17 +1270,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
template
<
typename
KGridDesc_K0_N_K1_
>
__device__
static
auto
MakeKGridDesc_N0_K_N1
(
const
KGridDesc_K0_N_K1_
&
k_grid_desc_k0_n_k1
)
{
const
auto
K
_K
0
=
k_grid_desc_k0_n_k1
.
GetLength
(
I0
);
const
auto
N
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K
_K
1
=
k_grid_desc_k0_n_k1
.
GetLength
(
I2
);
const
auto
K0
_
=
k_grid_desc_k0_n_k1
.
GetLength
(
I0
);
const
auto
N
_
=
k_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K1
_
=
k_grid_desc_k0_n_k1
.
GetLength
(
I2
);
constexpr
auto
K_
N1
=
B1K1
;
const
auto
K_
N0
=
N
/
K_
N1
;
constexpr
auto
N1
_
=
B1K1
;
const
auto
N0
_
=
N
_
/
N1
_
;
const
auto
k_grid_desc_n0_k_n1
=
transform_tensor_descriptor
(
k_grid_desc_k0_n_k1
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K_
N0
,
K_
N1
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
K
_K0
,
K_K1
))),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
N0
_
,
N1
_
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
K
0_
,
K1_
))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
...
...
@@ -1084,75 +1312,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
         }
     };

-    struct SharedMemTrait
-    {
-        // LDS allocation for A and B: be careful of alignment
-        static constexpr auto a_block_desc_ak0_m_ak1 =
-            GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
-        static constexpr auto b_block_desc_bk0_n_bk1 =
-            GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
-        static constexpr auto b1_block_desc_bk0_n_bk1 =
-            GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
-        static constexpr auto a2_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
-        static constexpr auto b2_block_desc_k0_n_k1 = GetB2BlockDescriptor_K0_N_K1<Gemm2Params>();
-
-        static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
-
-        static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
-            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
-        static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
-            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
-        static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
-            b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
-        static constexpr auto a2_block_space_size_aligned = math::integer_least_multiple(
-            a2_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
-        static constexpr auto b2_block_space_size_aligned = math::integer_least_multiple(
-            b2_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
-
-        static constexpr auto a_block_space_offset  = 0;
-        static constexpr auto b_block_space_offset  = a_block_space_size_aligned.value;
-        static constexpr auto b1_block_space_offset = 0;
-        static constexpr auto a2_block_space_offset = 0;
-        static constexpr auto b2_block_space_offset = a2_block_space_size_aligned.value;
-
-        // LDS allocation for reduction
-        static constexpr index_t reduction_space_size_aligned =
-            math::integer_least_multiple(BlockSize, max_lds_align);
-        static constexpr auto reduction_space_offset = 0;
-
-        // LDS allocation for C shuffle in LDS
-        static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
-            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
-        static constexpr auto c_block_space_size =
-            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
-    };
-
-    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
-    {
-        const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
-                                         SharedMemTrait::b_block_space_size_aligned) *
-                                        sizeof(GemmDataType);
-        const index_t gemm1_bytes_end = (SharedMemTrait::b1_block_space_offset +
-                                         SharedMemTrait::b1_block_space_size_aligned) *
-                                        sizeof(GemmDataType);
-        const index_t gemm2_bytes_end = (SharedMemTrait::a2_block_space_size_aligned +
-                                         SharedMemTrait::b2_block_space_size_aligned) *
-                                        sizeof(GemmDataType);
-        const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
-                                           SharedMemTrait::reduction_space_size_aligned) *
-                                          sizeof(FloatGemmAcc);
-        const index_t c_block_bytes_end =
-            SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
-
-        return math::max(gemm0_bytes_end,
-                         gemm1_bytes_end,
-                         gemm2_bytes_end,
-                         softmax_bytes_end,
-                         c_block_bytes_end);
-    }
-
     // D0
     static constexpr auto D0M2 = Number<4>{};
     static constexpr auto D0M1 = Number<MPerXdl>{} / D0M2;
...
...
@@ -1185,12 +1344,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     template <typename DataType>
     struct TypeTransform
     {
         using Type = DataType;
+        static constexpr index_t Size0 = sizeof(DataType);
+        static constexpr index_t Size  = sizeof(DataType);
     };

     template <>
     struct TypeTransform<void>
     {
         using Type = ck::half_t;
+        static constexpr index_t Size0 = 0;
+        static constexpr index_t Size  = sizeof(ck::half_t);
     };

     static constexpr index_t NThreadClusterLengths = 32;
     static_assert(NPerXdl == 32);
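The `Size0`/`Size` pair lets the LDS math later in this commit treat an absent (void) D0 bias as zero allocated bytes while still having a concrete element type for stride arithmetic. A minimal standalone sketch of the same idiom (`half_stub` is a stand-in, not ck::half_t):

    #include <cstdint>

    using half_stub = std::uint16_t; // stand-in for ck::half_t in this sketch

    template <typename DataType>
    struct TypeTransform
    {
        using Type = DataType;
        static constexpr int Size0 = sizeof(DataType); // bytes actually allocated
        static constexpr int Size  = sizeof(DataType); // bytes used for layout math
    };

    template <>
    struct TypeTransform<void> // optional bias tensor not provided
    {
        using Type = half_stub;
        static constexpr int Size0 = 0;                  // allocates nothing
        static constexpr int Size  = sizeof(half_stub);  // still a valid divisor
    };

    static_assert(TypeTransform<void>::Size0 == 0);
    static_assert(TypeTransform<float>::Size0 == 4);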
...
...
@@ -1254,7 +1417,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
true
,
// DstResetCoord
1
>
;
using
D0ThreadCopy
=
using
D0Thread
Wise
Copy
=
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
decltype
(
d0_block_read_desc_n0_n1_m0_m1_m2
),
// SrcDesc
...
...
@@ -1266,6 +1429,89 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
2
>
;
};
struct
SharedMemTrait
{
// LDS allocation for K
static
constexpr
auto
k_block_desc_k0_n_k1
=
GetKBlockDescriptor_K0PerBlock_NPerBlock_K1
();
// LDS allocation for A and B: be careful of alignment
static
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
static
constexpr
auto
b1_block_desc_bk0_n_bk1
=
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
static
constexpr
auto
a2_block_desc_k0_m_k1
=
GetA2BlockDescriptor_K0_M_K1
<
Gemm2Params
>
();
static
constexpr
auto
max_lds_align
=
Number
<
16
/
sizeof
(
GemmDataType
)
>
{};
static
constexpr
auto
k_block_space_size_aligned
=
math
::
integer_least_multiple
(
k_block_desc_k0_n_k1
.
GetElementSpaceSize
(),
max_lds_align
);
static
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
max_lds_align
);
static
constexpr
auto
b1_block_space_size_aligned
=
math
::
integer_least_multiple
(
b1_block_desc_bk0_n_bk1
.
GetElementSpaceSize
(),
max_lds_align
);
static
constexpr
auto
a2_block_space_size_aligned
=
math
::
integer_least_multiple
(
a2_block_desc_k0_m_k1
.
GetElementSpaceSize
(),
max_lds_align
);
static
constexpr
auto
k_block_space_offset
=
0
;
static
constexpr
auto
a_block_space_offset
=
k_block_space_size_aligned
.
value
;
static
constexpr
auto
b1_block_space_offset
=
k_block_space_size_aligned
.
value
;
static
constexpr
auto
a2_block_space_offset
=
k_block_space_size_aligned
.
value
;
// LDS allocation for reduction
static
constexpr
index_t
reduction_space_size_aligned
=
math
::
integer_least_multiple
(
BlockSize
,
max_lds_align
);
static
constexpr
auto
reduction_space_offset
=
(
math
::
max
(
a_block_space_size_aligned
.
value
,
b1_block_space_size_aligned
.
value
,
a2_block_space_size_aligned
.
value
)
+
k_block_space_size_aligned
.
value
)
*
sizeof
(
GemmDataType
)
/
sizeof
(
FloatGemmAcc
);
static
constexpr
auto
d0_block_space_size_aligned
=
math
::
integer_least_multiple
(
D0Loader
::
d0_block_write_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
(),
max_lds_align
);
static
constexpr
auto
d0_block_space_offset
=
k_block_space_size_aligned
.
value
*
sizeof
(
GemmDataType
)
/
D0Loader
::
template
TypeTransform
<
D0DataType
>
::
Size
;
// LDS allocation for C shuffle in LDS
static
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
static
constexpr
auto
c_block_space_size
=
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
();
};
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
const
index_t
gemm0_bytes_end
=
(
SharedMemTrait
::
k_block_space_size_aligned
+
SharedMemTrait
::
a_block_space_size_aligned
)
*
sizeof
(
GemmDataType
);
const
index_t
gemm1_bytes_end
=
(
SharedMemTrait
::
k_block_space_size_aligned
+
SharedMemTrait
::
b1_block_space_size_aligned
)
*
sizeof
(
GemmDataType
);
const
index_t
gemm2_bytes_end
=
(
SharedMemTrait
::
k_block_space_size_aligned
+
SharedMemTrait
::
a2_block_space_size_aligned
)
*
sizeof
(
GemmDataType
);
const
index_t
gemm3_bytes_end
=
(
SharedMemTrait
::
k_block_space_size_aligned
+
SharedMemTrait
::
a_block_space_size_aligned
)
*
sizeof
(
GemmDataType
);
const
index_t
softmax_bytes_end
=
(
SharedMemTrait
::
reduction_space_offset
+
SharedMemTrait
::
reduction_space_size_aligned
)
*
sizeof
(
FloatGemmAcc
);
const
index_t
d0_bytes_end
=
(
SharedMemTrait
::
d0_block_space_offset
+
SharedMemTrait
::
d0_block_space_size_aligned
)
*
D0Loader
::
template
TypeTransform
<
D0DataType
>
::
Size0
;
const
index_t
c_block_bytes_end
=
SharedMemTrait
::
c_block_space_size
*
sizeof
(
FloatCShuffle
);
return
math
::
max
(
gemm0_bytes_end
,
gemm1_bytes_end
,
gemm2_bytes_end
,
gemm3_bytes_end
,
softmax_bytes_end
,
d0_bytes_end
,
c_block_bytes_end
);
}
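Because every stage reuses one dynamic-LDS arena at different offsets, the required allocation is the maximum byte end across stages, not their sum. A host-side sketch of the same max-of-ends computation with made-up sizes:

    #include <algorithm>
    #include <cstdio>
    #include <iterator>

    int main()
    {
        // Hypothetical aligned element counts (in GemmDataType elements) and type sizes.
        const int k_lds = 4096, a_lds = 4096, b1_lds = 2048, a2_lds = 2048;
        const int sizeof_gemm = 2, sizeof_acc = 4;

        // Each stage's "bytes end" = shared K tile + that stage's private tile.
        const int ends[] = {(k_lds + a_lds) * sizeof_gemm,  // gemm0 / gemm3
                            (k_lds + b1_lds) * sizeof_gemm, // gemm1
                            (k_lds + a2_lds) * sizeof_gemm, // gemm2
                            (k_lds + a_lds) * sizeof_gemm + 256 * sizeof_acc}; // + reduction
        std::printf("LDS bytes = %d\n", *std::max_element(std::begin(ends), std::end(ends)));
    }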
    template <bool HasMainKBlockLoop,
              bool IsDropout,
              typename Block2CTileMap,
...
...
@@ -1294,7 +1540,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
         const D0GridDescriptor_M0_N0_M1_M2_N1_M3& d0_grid_desc_m0_n0_m1_m2_n1_m3,
         const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3& z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
-        const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
+        const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1,
         const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock& y_grid_desc_mblock_mperblock_oblock_operblock,
         const LSEGridDesc_M& lse_grid_desc_m,
...
...
@@ -1319,7 +1565,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
auto
k_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_k_grid
,
k_grid_desc_k0_n_k1
.
GetElementSpaceSize
());
const
auto
v_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_v_grid
,
v_grid_desc_
n0_o_n
1
.
GetElementSpaceSize
());
p_v_grid
,
v_grid_desc_
o0_n_o
1
.
GetElementSpaceSize
());
const
auto
y_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_y_grid
,
y_grid_desc_mblock_mperblock_oblock_operblock
.
GetElementSpaceSize
());
const
auto
lse_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
...
...
@@ -1327,7 +1573,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
auto
ygrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_ygrad_grid
,
ygrad_grid_desc_m0_o_m1
.
GetElementSpaceSize
());
auto
vgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_vgrad_grid
,
v_grid_desc_
n0_o_n
1
.
GetElementSpaceSize
());
p_vgrad_grid
,
v_grid_desc_
o0_n_o
1
.
GetElementSpaceSize
());
auto
qgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_qgrad_grid
,
q_grid_desc_k0_m_k1
.
GetElementSpaceSize
());
auto
kgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
...
...
@@ -1346,70 +1592,67 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
        const index_t num_gemm0_m_block_outer_loop =
            q_grid_desc_k0_m_k1.GetLength(I1) / MPerBlock;
        constexpr index_t num_gemm1_k_block_inner_loop = MPerBlock / Gemm1KPerBlock;

-       // 6 GEMM operations are categorized into 3 buckets. SizeK == SizeO == head_dim
-       // S_MNK / dP_MNO Gemm (Gemm0 rcc)
+       // 6 GEMM operations are categorized into 4 buckets. SizeK == SizeO == head_dim
+       // dP_MNO Gemm (Gemm0 rcc)
        // dV_NOM / dK_NKM Gemm (Gemm1 rrr)
        // Y_MON / dQ_MKN Gemm (Gemm2 crr)
+       // S_MNK Gemm (Gemm3 rcc)

-       //
-       // set up S / dP Gemm (type 1 rcc)
-       //
-
-       // Gemm0: LDS allocation for A and B: be careful of alignment
-       auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
-           Gemm0::a_block_desc_ak0_m_ak1.GetElementSpaceSize());
-       auto gemm0_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<GemmDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
-           Gemm0::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
+       // LDS allocation for K
+       auto k_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+           static_cast<GemmDataType*>(p_shared) + SharedMemTrait::k_block_space_offset,
+           GemmBlockwiseCopy::k_block_desc_k0_n_k1.GetElementSpaceSize());

-       // Gemm0: gridwise GEMM pipeline
-       // Only supports LoopScheduler::Default
-       const auto gemm0_gridwise_gemm_pipeline = GridwiseGemmPipeline_Selector<
-           PipelineVer, NumGemmKPrefetchStage, LoopScheduler::Default>();
-
-       // S: A matrix blockwise copy
-       auto s_gemm_tile_q_blockwise_copy =
-           typename Gemm0::template ABlockwiseCopy<decltype(q_grid_desc_k0_m_k1)>(
-               q_grid_desc_k0_m_k1,
-               make_multi_index(0, MPerBlock * (num_gemm0_m_block_outer_loop - 1), 0),
-               // will loop over GemmM dimension
-               a_element_op,
-               Gemm0::a_block_desc_ak0_m_ak1,
-               make_multi_index(0, 0, 0),
-               tensor_operation::element_wise::PassThrough{});
-
-       // S: B matrix blockwise copy
-       auto s_gemm_tile_k_blockwise_copy =
-           typename Gemm0::template BBlockwiseCopy<decltype(k_grid_desc_k0_n_k1)>(
+       // K matrix blockwise copy
+       auto gemm_tile_k_blockwise_copy =
+           typename GemmBlockwiseCopy::template KBlockwiseCopy<decltype(k_grid_desc_k0_n_k1)>(
               k_grid_desc_k0_n_k1,
               make_multi_index(0, n_block_data_idx_on_grid, 0),
               b_element_op,
-              Gemm0::b_block_desc_bk0_n_bk1,
+              GemmBlockwiseCopy::k_block_desc_k0_n_k1,
               make_multi_index(0, 0, 0),
               tensor_operation::element_wise::PassThrough{});

-       // S: blockwise gemm
-       auto s_blockwise_gemm = typename Gemm0::BlockwiseGemm{}; // TransposeC
+       // Vgpr allocation for V
+       auto v_thread_buf = generate_tuple(
+           [&](auto i) {
+               ignore = i;
+               return StaticBuffer<AddressSpaceEnum::Vgpr,
+                                   GemmDataType,
+                                   GemmBlockwiseCopy::v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
+                                       .GetElementSpaceSize(),
+                                   true>{};
+           },
+           Number<GemmBlockwiseCopy::VBlockBufferSize>{});
+
+       const auto v_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
+           GemmBlockwiseCopy::MakeVGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(v_grid_desc_o0_n_o1);

-       auto s_slash_p_thread_buf = s_blockwise_gemm.GetCThreadBuffer();
+       const auto wave_id     = GetGemm0WaveIdx();
+       const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63

-       const auto s_gemm_tile_q_block_reset_copy_step =
-           make_multi_index(-q_grid_desc_k0_m_k1.GetLength(I0), -MPerBlock, 0);
-       const auto s_gemm_tile_k_block_reset_copy_step =
-           make_multi_index(-k_grid_desc_k0_n_k1.GetLength(I0), 0, 0);
+       // V matrix blockwise copy
+       auto gemm_tile_v_blockwise_copy =
+           typename GemmBlockwiseCopy::template VBlockwiseCopy<
+               decltype(v_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3)>(
+               v_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+               make_multi_index(
+                   0, 0, wave_m_n_id[I0], block_work_idx_n, 0, wave_id[I1], wave_m_n_id[I1], 0));

        //
        // set up dP Gemm (type 1 rcc)
        //
-       const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
-           (q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2)) / KPerBlock);
+       using Gemm0 = Gemm0<decltype(GemmBlockwiseCopy::v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3)>;

-       // dP: transform input and output tensor descriptors
+       // Gemm0: LDS allocation for A and B: be careful of alignment
+       auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+           static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
+           Gemm0::a_block_desc_ak0_m_ak1.GetElementSpaceSize());
+
+       // dP: transform input tensor descriptors
        const auto ygrad_grid_desc_o0_m_o1 =
            PGradGemmTile_M_N_O::MakeYGradGridDesc_O0_M_O1(ygrad_grid_desc_m0_o_m1);
-       const auto v_grid_desc_o0_n_o1 =
-           PGradGemmTile_M_N_O::MakeVGridDesc_O0_N_O1(v_grid_desc_n0_o_n1);

        // dP: A matrix blockwise copy
        auto pgrad_gemm_tile_ygrad_blockwise_copy =
...
...
@@ -1423,30 +1666,47 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                make_multi_index(0, 0, 0),
                tensor_operation::element_wise::PassThrough{});

-       // dP: B matrix blockwise copy
-       auto pgrad_gemm_tile_v_blockwise_copy =
-           typename Gemm0::template BBlockwiseCopy<decltype(v_grid_desc_o0_n_o1)>(
-               v_grid_desc_o0_n_o1,
-               make_multi_index(0, n_block_data_idx_on_grid, 0),
-               tensor_operation::element_wise::PassThrough{},
-               Gemm0::b_block_desc_bk0_n_bk1,
-               make_multi_index(0, 0, 0),
-               tensor_operation::element_wise::PassThrough{});
-
        // dP: blockwise gemm
+       // we need separate blockwise gemm object because we need separate thread buffer
        auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{};
+       pgrad_blockwise_gemm.SetBBlockStartWindow(make_tuple(0, 0, 0, 0));

        auto pgrad_thread_buf = pgrad_blockwise_gemm.GetCThreadBuffer();

        const auto pgrad_gemm_tile_ygrad_block_reset_copy_step =
            make_multi_index(-ygrad_grid_desc_o0_m_o1.GetLength(I0), -MPerBlock, 0);
-       const auto pgrad_gemm_tile_v_block_reset_copy_step =
-           make_multi_index(-v_grid_desc_o0_n_o1.GetLength(I0), 0, 0);

-       const index_t num_o_block_main_loop = __builtin_amdgcn_readfirstlane(
-           (ygrad_grid_desc_o0_m_o1.GetLength(I0) * ygrad_grid_desc_o0_m_o1.GetLength(I2)) /
-           KPerBlock);
+       constexpr index_t num_ok_block_main_loop = Gemm1NPerBlock / KPerBlock;
+
+       //
+       // set up S Gemm (type 4 rcc)
+       //
+       using Gemm3 = Gemm3<decltype(GemmBlockwiseCopy::k_block_desc_k0_n_k1)>;
+
+       // Gemm3: LDS allocation for A and B: be careful of alignment
+       auto gemm3_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+           static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
+           Gemm3::a_block_desc_ak0_m_ak1.GetElementSpaceSize());
+
+       // S: A matrix blockwise copy
+       auto s_gemm_tile_q_blockwise_copy =
+           typename Gemm3::template ABlockwiseCopy<decltype(q_grid_desc_k0_m_k1)>(
+               q_grid_desc_k0_m_k1,
+               make_multi_index(0, MPerBlock * (num_gemm0_m_block_outer_loop - 1), 0),
+               // will loop over GemmM dimension
+               a_element_op,
+               Gemm3::a_block_desc_ak0_m_ak1,
+               make_multi_index(0, 0, 0),
+               tensor_operation::element_wise::PassThrough{});
+
+       // S: blockwise gemm
+       auto s_blockwise_gemm = typename Gemm3::BlockwiseGemm{}; // TransposeC
+
+       auto s_slash_p_thread_buf = s_blockwise_gemm.GetCThreadBuffer();
+
+       const auto s_gemm_tile_q_block_reset_copy_step =
+           make_multi_index(-q_grid_desc_k0_m_k1.GetLength(I0), -MPerBlock, 0);

        //
        // set up dV / dK Gemm (type 2 rrr)
...
...
@@ -1490,7 +1750,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// dV: transform input and output tensor descriptors
auto
vgrad_grid_desc_nblock_nperblock_oblock_operblock
=
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
v_grid_desc_
n0_o_n
1
);
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
v_grid_desc_
o0_n_o
1
);
// dK: transform input and output tensor descriptors
const
auto
q_grid_desc_m0_k_m1
=
...
...
@@ -1528,23 +1788,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
//
// set up dQ Gemm (type 3 crr)
//
using
Gemm2
=
Gemm2
<
Gemm2Params
,
decltype
(
pgrad_blockwise_gemm
)
>
;
using
Gemm2
=
Gemm2
<
Gemm2Params
,
decltype
(
pgrad_blockwise_gemm
),
decltype
(
GemmBlockwiseCopy
::
k_block_desc_k0_n_k1
)
>
;
// Gemm2: LDS allocation for A and B: be careful of alignment
auto
gemm2_a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
GemmDataType
*>
(
p_shared
)
+
SharedMemTrait
::
a2_block_space_offset
,
Gemm2
::
a_block_desc_k0_m_k1
.
GetElementSpaceSize
());
auto
gemm2_b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
GemmDataType
*>
(
p_shared
)
+
SharedMemTrait
::
b2_block_space_offset
,
Gemm2
::
b_block_desc_k0_n_k1
.
GetElementSpaceSize
());
// // dQ: transform input and output tensor descriptors
// const auto vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
// Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(vgrad_grid_desc_n_o);
// dQ: transform input and output tensor descriptors
const
auto
k_grid_desc_n0_k_n1
=
QGradGemmTile_M_K_N
::
MakeKGridDesc_N0_K_N1
(
k_grid_desc_k0_n_k1
);
auto
gemm2_b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
GemmDataType
>
(
Gemm2
::
b_thread_desc_n0_n1_n2_k0_k1_k2_k3
.
GetElementSpaceSize
());
// dQ: A matrix VGPR-to-LDS blockwise copy
auto
qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds
=
...
...
@@ -1553,18 +1807,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
Gemm2
::
MakeAThreadOriginOnBlock_M0_K0_M1_K1_M2_M3_M4_K2
(),
tensor_operation
::
element_wise
::
PassThrough
{}};
// dQ: B matrix global-to-LDS blockwise copy
auto
qgrad_gemm_tile_k_blockwise_copy
=
typename
Gemm2
::
template
BBlockwiseCopy
<
decltype
(
k_grid_desc_n0_k_n1
)>(
k_grid_desc_n0_k_n1
,
make_multi_index
(
n_block_data_idx_on_grid
/
Gemm2Params
::
B_K1
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{},
Gemm2
::
b_block_desc_k0_n_k1
,
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
// dQ: blockwise gemm
auto
qgrad_blockwise_gemm
=
typename
Gemm2
::
BlockwiseGemm
{};
qgrad_blockwise_gemm
.
SetBBlockStartWindow
(
make_tuple
(
0
,
0
,
0
,
0
));
auto
k_thread_origin
=
qgrad_blockwise_gemm
.
CalculateBThreadOriginDataIndex
();
// dQ: B matrix LDS-to-VGPR blockwise copy
auto
qgrad_gemm_tile_k_blockwise_copy
=
typename
Gemm2
::
BBlockwiseCopy
{
Gemm2
::
b_block_desc_n0_n1_n2_k0_k1_k2_k3
,
make_multi_index
(
0
,
// nrepeat
k_thread_origin
[
I1
],
// nwave
k_thread_origin
[
I2
],
// nperxdl
0
,
// k0
0
,
// k1
k_thread_origin
[
I3
]
/
Gemm2Params
::
GemmKPack
,
// k2
0
)};
// k3
auto
qgrad_thread_buf
=
qgrad_blockwise_gemm
.
GetCThreadBuffer
();
...
...
@@ -1704,9 +1962,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
());
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ZDataType
,
...
...
@@ -1740,6 +1995,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
0
,
//
wave_m_n_id
[
I1
]),
// NPerXdl
tensor_operation
::
element_wise
::
PassThrough
{}};
//
// set up Y dot dY
//
...
...
@@ -1790,7 +2046,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
auto
ygrad_thread_buf
=
typename
YDotYGrad_M_O
::
SrcBufType
{};
auto
y_dot_ygrad_thread_accum_buf
=
typename
YDotYGrad_M_O
::
DstBufType
{};
auto
y_dot_ygrad_block_accum_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatGemmAcc
*>
(
p_shared
),
MPerBlock
);
static_cast
<
FloatGemmAcc
*>
(
p_shared
)
+
SharedMemTrait
::
reduction_space_offset
,
MPerBlock
);
constexpr
auto
y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
P_M0
,
P_M1
,
P_M2
,
P_M3
,
P_M4
));
...
...
@@ -1821,17 +2078,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
auto
y_dot_ygrad_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatGemmAcc
>
(
y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4
.
GetElementSpaceSize
());
if
constexpr
(
Deterministic
)
{
block_sync_lds
();
}
// Initialize dK&dV
kgrad_thread_buf
.
Clear
();
vgrad_thread_buf
.
Clear
();
// gemm0 M loop
index_t
gemm0_m_block_outer_index
=
num_gemm0_m_block_outer_loop
-
1
;
// D0
auto
d0_block_copy_global_to_lds
=
typename
D0Loader
::
D0BlockwiseCopy
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
...
...
@@ -1841,8 +2090,36 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Loader
::
D0ThreadCopy
(
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Loader
::
D0Thread
Wise
Copy
(
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
));
if
constexpr
(
Deterministic
)
{
block_sync_lds
();
}
// Initialize dK&dV
kgrad_thread_buf
.
Clear
();
vgrad_thread_buf
.
Clear
();
// load k
gemm_tile_k_blockwise_copy
.
Run
(
k_grid_desc_k0_n_k1
,
k_grid_buf
,
GemmBlockwiseCopy
::
k_block_desc_k0_n_k1
,
k_block_buf
,
I0
);
// load v
static_for
<
0
,
GemmBlockwiseCopy
::
VBlockBufferSize
,
1
>
{}([
&
](
auto
ii
)
{
gemm_tile_v_blockwise_copy
.
Run
(
v_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
v_grid_buf
,
GemmBlockwiseCopy
::
v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
v_thread_buf
(
Number
<
ii
>
{}));
gemm_tile_v_blockwise_copy
.
MoveSrcSliceWindow
(
v_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
GemmBlockwiseCopy
::
v_block_slice_copy_step
);
});
do
{
auto
m_block_data_idx_on_grid
=
...
...
@@ -1909,22 +2186,55 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                lse_thread_buf);

            // S = Q * K^T
-           gemm0_gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
-               q_grid_desc_k0_m_k1,
-               Gemm0::a_block_desc_ak0_m_ak1,
-               s_gemm_tile_q_blockwise_copy,
-               q_grid_buf,
-               gemm0_a_block_buf,
-               Gemm0::a_block_slice_copy_step,
-               k_grid_desc_k0_n_k1,
-               Gemm0::b_block_desc_bk0_n_bk1,
-               s_gemm_tile_k_blockwise_copy,
-               k_grid_buf,
-               gemm0_b_block_buf,
-               Gemm0::b_block_slice_copy_step,
-               s_blockwise_gemm,
-               s_slash_p_thread_buf,
-               num_k_block_main_loop);
+           {
+               // preload data into LDS
+               s_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_k0_m_k1, q_grid_buf);
+               s_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_k0_m_k1,
+                                                               Gemm3::a_block_slice_copy_step);
+
+               block_sync_lds(); // wait for previous LDS read
+
+               s_slash_p_thread_buf.Clear();
+               s_gemm_tile_q_blockwise_copy.RunWrite(Gemm3::a_block_desc_ak0_m_ak1,
+                                                     gemm3_a_block_buf);
+
+               // main body
+               if constexpr(HasMainKBlockLoop)
+               {
+                   index_t i = 0;
+                   do
+                   {
+                       s_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_k0_m_k1, q_grid_buf);
+
+                       block_sync_lds();
+                       s_blockwise_gemm.Run(gemm3_a_block_buf, k_block_buf, s_slash_p_thread_buf);
+                       s_blockwise_gemm.MoveBBlockSrcSliceWindow(Gemm3::b_block_slice_copy_step);
+
+                       block_sync_lds();
+                       s_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
+                           q_grid_desc_k0_m_k1, Gemm3::a_block_slice_copy_step);
+                       s_gemm_tile_q_blockwise_copy.RunWrite(Gemm3::a_block_desc_ak0_m_ak1,
+                                                             gemm3_a_block_buf);
+
+                       ++i;
+                   } while(i < (num_ok_block_main_loop - 1));
+               }
+
+               // tail
+               {
+                   block_sync_lds();
+                   s_blockwise_gemm.Run(gemm3_a_block_buf, k_block_buf, s_slash_p_thread_buf);
+                   s_blockwise_gemm.MoveBBlockSrcSliceWindow(Gemm3::b_block_slice_copy_step);
+               }
+           }
+           // end gemm S
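The replacement above hand-rolls the same preload / main-body / tail shape that the removed GridwiseGemmPipeline_Selector call provided. A stripped-down sketch of that control flow (the callbacks are hypothetical stand-ins for the RunRead/RunWrite/Run/block_sync_lds calls, not the library's API):

    #include <functional>

    // Generic shape of the hand-rolled S-GEMM loop above.
    void pipelined_loop(int num_loops,
                        const std::function<void()>& read_tile,
                        const std::function<void()>& write_tile,
                        const std::function<void()>& compute,
                        const std::function<void()>& sync)
    {
        read_tile();  // preload the first tile
        write_tile();
        for(int i = 0; i + 1 < num_loops; ++i) // main body runs num_loops - 1 times
        {
            read_tile();  // prefetch the next tile while the current one is consumed
            sync();       // tile in LDS is ready
            compute();
            sync();       // all reads done before the tile is overwritten
            write_tile();
        }
        sync();
        compute(); // tail: the last tile has no successor to prefetch
    }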
            // 8d thread_desc in thread scope
            constexpr auto c_thread_lengths =
...
...
@@ -1993,7 +2303,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_d0_grid
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
a
_block_space_offset
,
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0
_block_space_offset
,
D0Loader
::
d0_block_write_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
D0DataType
>
(
...
...
@@ -2107,11 +2417,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// dV = P_drop^T * dY
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements
//
RunRead(),
RunWrite(), and MoveSliceWindow(). But it is impossible to
//
implement given that
the A1 source buffer is static buffer holding the output
//
of first GEMM and
requires constexpr offset by design. Therefore, we pass
//
tensor coordinate offset
explicitly in Run() below.
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements
RunRead(),
// RunWrite(), and MoveSliceWindow(). But it is impossible to
implement given that
// the A1 source buffer is static buffer holding the output
of first GEMM and
// requires constexpr offset by design. Therefore, we pass
tensor coordinate offset
// explicitly in Run() below.
// preload data into LDS
vgrad_gemm_tile_ygrad_blockwise_copy
.
RunRead
(
ygrad_grid_desc_m0_o_m1
,
...
...
@@ -2173,22 +2483,51 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                block_sync_lds();

                // dP = dY * V^T
-               // assume size K == size O so HasMainKBlockLoop is the same
-               gemm0_gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
-                   ygrad_grid_desc_o0_m_o1,
-                   Gemm0::a_block_desc_ak0_m_ak1, // reuse
-                   pgrad_gemm_tile_ygrad_blockwise_copy,
-                   ygrad_grid_buf,
-                   gemm0_a_block_buf,             // reuse
-                   Gemm0::a_block_slice_copy_step, // reuse
-                   v_grid_desc_o0_n_o1,
-                   Gemm0::b_block_desc_bk0_n_bk1, // reuse
-                   pgrad_gemm_tile_v_blockwise_copy,
-                   v_grid_buf,
-                   gemm0_b_block_buf,             // reuse
-                   Gemm0::b_block_slice_copy_step, // reuse
-                   pgrad_blockwise_gemm,
-                   pgrad_thread_buf,
-                   num_o_block_main_loop);
+               {
+                   // preload data into LDS
+                   pgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_o0_m_o1,
+                                                                ygrad_grid_buf);
+                   pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
+                       ygrad_grid_desc_o0_m_o1, Gemm0::a_block_slice_copy_step);
+
+                   block_sync_lds(); // wait for previous LDS read
+                   pgrad_thread_buf.Clear();
+                   pgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(Gemm0::a_block_desc_ak0_m_ak1,
+                                                                 gemm0_a_block_buf);
+
+                   // main body
+                   if constexpr(num_ok_block_main_loop > 1)
+                   {
+                       static_for<0, num_ok_block_main_loop - 1, 1>{}([&](auto i) {
+                           pgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_o0_m_o1,
+                                                                        ygrad_grid_buf);
+                           block_sync_lds();
+                           pgrad_blockwise_gemm.Run(
+                               gemm0_a_block_buf, v_thread_buf(Number<i>{}), pgrad_thread_buf);
+                           block_sync_lds();
+                           pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
+                               ygrad_grid_desc_o0_m_o1, Gemm0::a_block_slice_copy_step);
+                           pgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(
+                               Gemm0::a_block_desc_ak0_m_ak1, gemm0_a_block_buf);
+                       });
+                   }
+
+                   // tail
+                   {
+                       block_sync_lds();
+                       pgrad_blockwise_gemm.Run(
+                           gemm0_a_block_buf,
+                           v_thread_buf(Number<num_ok_block_main_loop - 1>{}),
+                           pgrad_thread_buf);
+                   }
+               }
                // end gemm dP
                // dS = P * (dP - Y_dot_dY)
                auto& sgrad_thread_buf = pgrad_thread_buf;
...
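For context, the update implemented here is the standard softmax-backward identity (a hedged restatement of the comment above, in the notation of the surrounding buffers):

\mathrm{d}S_{ij} = P_{ij}\,\bigl(\mathrm{d}P_{ij} - D_i\bigr), \qquad D_i = \sum_{o} Y_{i,o}\,\mathrm{d}Y_{i,o}

D is the per-row Y.dY sum produced by the ydotygrad kernel in the second file of this commit, and aliasing sgrad_thread_buf to pgrad_thread_buf lets the update happen in place.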
@@ -2220,9 +2559,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
            // dQ = scalar * dS * K
            qgrad_thread_buf.Clear();

            static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) {
                // gemm dQ
                // load QGrad Gemm B
                qgrad_gemm_tile_k_blockwise_copy.RunRead(k_grid_desc_n0_k_n1, k_grid_buf);

                // load QGrad Gemm A
                const auto sgrad_slice_idx =
                    Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
...
@@ -2245,16 +2581,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                        gemm2_a_block_buf);
                }

                // k slice window is moved with MoveSrcSliceWindow() since it is a dynamic buffer
                // sgrad slice window is moved by loop index
                qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
                    k_grid_desc_n0_k_n1, Gemm2::b_block_slice_copy_step);
+               qgrad_gemm_tile_k_blockwise_copy.Run(Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
+                                                    k_block_buf,
+                                                    Gemm2::b_thread_desc_n0_n1_n2_k0_k1_k2_k3,
+                                                    make_tuple(I0, I0, I0, I0, I0, I0, I0),
+                                                    gemm2_b_thread_buf);
-               qgrad_gemm_tile_k_blockwise_copy.RunWrite(Gemm2::b_block_desc_k0_n_k1,
-                                                         gemm2_b_block_buf);
+               qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
+                   Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3, Gemm2::b_block_slice_copy_step);

                block_sync_lds(); // sync before read
-               qgrad_blockwise_gemm.Run(gemm2_a_block_buf, gemm2_b_block_buf, qgrad_thread_buf);
+               qgrad_blockwise_gemm.Run(gemm2_a_block_buf, gemm2_b_thread_buf, qgrad_thread_buf);
            });
            // end gemm dQ
            // atomic_add dQ
            qgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                ...
...
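A worked form of the two comments above (hedged; alpha is the externally supplied scale factor):

\mathrm{d}Q_{m,k} = \alpha \sum_{n} \mathrm{d}S_{m,n}\,K_{n,k}

Each workgroup owns one block of keys, so every workgroup produces a partial dQ for the same query rows; committing the tile with an atomic add rather than a plain store lets those partials accumulate correctly in global memory.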
@@ -2267,11 +2604,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
            // dK = scalar * dS^T * Q
            {
                // TODO: explore using dynamic buffer for a1 thread buffer
                // For a1_blockwise_copy, the goal is to satisfy the pipeline requirements
                // RunRead(), RunWrite(), and MoveSliceWindow(). But that is impossible to
                // implement given that the A1 source buffer is a static buffer holding the
                // output of the first GEMM and requires a constexpr offset by design.
                // Therefore, we pass the tensor coordinate offset explicitly in Run() below.

                // preload data into LDS
                kgrad_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_m0_k_m1, q_grid_buf);
...
@@ -2331,17 +2668,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
        s_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
            q_grid_desc_k0_m_k1, s_gemm_tile_q_block_reset_copy_step); // rewind K and step M
-       s_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
-           k_grid_desc_k0_n_k1, s_gemm_tile_k_block_reset_copy_step); // rewind K
        pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
            ygrad_grid_desc_o0_m_o1,
            pgrad_gemm_tile_ygrad_block_reset_copy_step); // rewind O and step M
-       pgrad_gemm_tile_v_blockwise_copy.MoveSrcSliceWindow(
-           v_grid_desc_o0_n_o1, pgrad_gemm_tile_v_block_reset_copy_step); // rewind O
        qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
-           k_grid_desc_n0_k_n1,
+           Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
            Gemm2::b_block_reset_copy_step); // rewind N
        kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
            q_grid_desc_m0_k_m1, kgrad_gemm_tile_q_block_next_copy_step); // step M
...
@@ -2353,10 +2684,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
            make_multi_index(-1, 0, 0, 0, 0, 0));
        yygrad_threadwise_copy.MoveSrcSliceWindow(
            y_grid_desc_mblock_mperblock_oblock_operblock, make_multi_index(-1, 0, 0, 0));
+       s_blockwise_gemm.MoveBBlockSrcSliceWindow(Gemm3::b_block_reset_copy_step);
        z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
            make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
    } while(0 < gemm0_m_block_outer_index--); // end j loop

    // shuffle dK&dV and write
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_ydotygrad.hpp
View file @ 21ef37b4
...
@@ -142,8 +142,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
        const YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
            y_grid_desc_mblock_mperblock_nblock_nperblock,
        const DGridDesc_M& d_grid_desc_m,
-       const Block2CTileMap& block_2_ctile_map)
+       const Block2CTileMap& block_2_ctile_map,
+       const float p_drop)
    {
+       const FloatD p_dropout = type_convert<FloatD>(1.0f - p_drop);
+       const tensor_operation::element_wise::Scale scale_p_dropout(p_dropout);

        const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_y_grid, y_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
        const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            ...
...
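A hedged reading of the two added lines: the Scale element-wise op is constructed with the keep factor 1 - p_drop, so the D vector written out below is

D_i = (1 - p_{\mathrm{drop}})\sum_{o} Y_{i,o}\,\mathrm{d}Y_{i,o}

which, assuming this reading of Scale, is the D consumed by the dS = P * (dP - D) step in the b2t_v2 kernel above.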
@@ -247,7 +251,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
            FloatD,
            decltype(d_thread_desc_mblock_m1),
            decltype(d_grid_desc_mblock_mperblock),
-           ck::tensor_operation::element_wise::PassThrough,
+           ck::tensor_operation::element_wise::Scale,
            Sequence<1, 1>,
            Sequence<0, 1>,
            1,
            ...
...
@@ -258,7 +262,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
            d_grid_desc_mblock_mperblock,
            make_multi_index(block_work_idx_m,           // mblock
                             get_thread_local_1d_id()),  // mperblock
-           ck::tensor_operation::element_wise::PassThrough{}};
+           scale_p_dropout};

        // copy from VGPR to Global
        d_thread_copy_vgpr_to_global.Run(d_thread_desc_mblock_m1,
            ...
...
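For readers tracing the PassThrough to Scale swap: ck element-wise ops are small functors applied per element during a transfer. A hedged standalone sketch of the shape such a Scale op takes (not the exact ck definition):

// Hedged sketch of an element-wise scale functor; ck's real Scale lives in
// tensor_operation::element_wise and may differ in detail.
struct ScaleSketch
{
    explicit ScaleSketch(float scale) : scale_(scale) {}

    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const
    {
        y = static_cast<Y>(scale_ * static_cast<float>(x)); // y = scale * x, per element
    }

    float scale_;
};

Passing scale_p_dropout at construction means the dropout factor is applied during the VGPR to global copy itself, so no extra pass over D is needed.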