gaoqiong / composable_kernel · Commits · 825f7f02

Commit 825f7f02, authored Dec 20, 2022 by Anthony Chang

    refactor Gemm1

parent c798cff9
Showing 2 changed files with 475 additions and 456 deletions (+475 −456).
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp  (+464 −456)
include/ck/utility/statically_indexed_array_multi_index.hpp  (+11 −0)
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
...
...
@@ -108,140 +108,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
    static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
    static constexpr auto B1K1 = Number<B1K1Value>{};

    // VGrad Gemm
    template <index_t Sum_M_ = MPerXdl * 2>
    struct VGradGemmTile_N_O_M_
    {
        static constexpr index_t Free0_N = NPerBlock;
        static constexpr index_t Free1_O = Gemm1NPerBlock;
        static constexpr index_t Sum_M   = Sum_M_;

        static constexpr index_t P_M1     = 8; // P will be row-major
        static constexpr index_t P_M0     = Sum_M / P_M1;
        static constexpr index_t P_LdsPad = 0; // how many multiples of M1 per N * M1 elements

        static constexpr index_t YGrad_M1     = 2; // dY assumed row-major, typically =2 for fp16
        static constexpr index_t YGrad_M0     = Sum_M / YGrad_M1;
        static constexpr index_t YGrad_LdsPad = 0; // how many multiples of M1 per N * M1 elements

        static_assert(Sum_M % MPerXdl == 0, "");

        static constexpr index_t YGrad_SrcVectorDim       = 1; // Free1_O dimension
        static constexpr index_t YGrad_SrcScalarPerVector = 4;

        static constexpr index_t GemmNWave   = 2;
        static constexpr index_t GemmOWave   = BlockSize / get_warp_size() / GemmNWave;
        static constexpr index_t GemmNRepeat = Free0_N / GemmNWave / MPerXdl;
        static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl;
        static constexpr index_t GemmMPack =
            math::max(math::lcm(P_M1, YGrad_M1),
                      MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

        using YGrad_BlockSliceLengths = Sequence<YGrad_M0, Free1_O, YGrad_M1>;
        using YGrad_ThreadClusterLengths =
            Sequence<BlockSize / (Free1_O / YGrad_SrcScalarPerVector),
                     Free1_O / YGrad_SrcScalarPerVector,
                     1>;
        using YGrad_ThreadClusterArrangeOrder = Sequence<0, 2, 1>;

        __host__ __device__ static constexpr auto GetPBlockDescriptor_M0_N_M1()
        {
            constexpr index_t P_M0 = Sum_M / P_M1;
            return make_naive_tensor_descriptor(
                make_tuple(Number<P_M0>{}, Number<Free0_N>{}, Number<P_M1>{}),
                make_tuple(Number<Free0_N + P_LdsPad>{} * Number<P_M1>{}, Number<P_M1>{}, I1));
        }

        __host__ __device__ static constexpr auto GetYGradBlockDescriptor_M0_O_M1()
        {
            constexpr index_t YGrad_M0 = Sum_M / YGrad_M1;
            return make_naive_tensor_descriptor(
                make_tuple(Number<YGrad_M0>{}, Number<Free1_O>{}, Number<YGrad_M1>{}),
                make_tuple(Number<Free1_O + YGrad_LdsPad>{} * Number<YGrad_M1>{},
                           Number<YGrad_M1>{},
                           I1));
        }

        __host__ __device__ static constexpr auto GetPBlockSliceLengths_M0_N0_M1_N1_M2_N2()
        {
            // perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl
            constexpr index_t m  = Sum_M - 1;
            constexpr index_t m2 = m % MPerXdl;
            constexpr index_t m1 = m / MPerXdl % Gemm0MWaves;
            constexpr index_t m0 = m / MPerXdl / Gemm0MWaves % MXdlPerWave;

            // perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl
            constexpr index_t n  = Free0_N - 1;
            constexpr index_t n2 = n % NPerXdl;
            constexpr index_t n1 = n / NPerXdl % Gemm0NWaves;
            constexpr index_t n0 = n / NPerXdl / Gemm0NWaves % NXdlPerWave;

            // assume 256 decomposed into 2 x 4 x 32
            // 1d idx ( 32 - 1) -> 3d idx 0, 0, 31 -> 3d dim 1 x 1 x 32
            // 1d idx (256 - 1) -> 3d idx 1, 3, 31 -> 3d dim 2 x 4 x 32
            return Sequence<m0, n0, m1, n1, m2, n2>{} + Sequence<1, 1, 1, 1, 1, 1>{};
        }

        __host__ __device__ static constexpr auto GetPBlockSliceLengths_M0_N0_M1_N1()
        {
            return generate_sequence_v2(
                [](auto I) { return GetPBlockSliceLengths_M0_N0_M1_N1_M2_N2().At(I); },
                Number<4>{});
        }
    };
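Aside (editorial illustration, not part of the commit): the "manual unmerge" above turns a merged 1-D length into per-repeat / per-wave / per-xdl extents by decomposing the last valid index and then adding one to each coordinate, exactly as the 2 x 4 x 32 comment sketches. A minimal standalone C++ check of that arithmetic, with the 2 x 4 x 32 factors assumed for illustration:

    #include <array>
    #include <cassert>

    int main()
    {
        // hypothetical factors: repeat x waves x per_xdl = 2 x 4 x 32 (merged length 256)
        constexpr int waves = 4, per_xdl = 32;

        auto unmerge = [&](int idx) {
            // same order as the kernel: slowest (repeat), middle (waves), fastest (per_xdl)
            return std::array<int, 3>{idx / per_xdl / waves, idx / per_xdl % waves, idx % per_xdl};
        };

        // last index of a 32-wide slice  -> (0, 0, 31) -> occupied extents 1 x 1 x 32
        assert((unmerge(32 - 1) == std::array<int, 3>{0, 0, 31}));
        // last index of the full 256 slice -> (1, 3, 31) -> occupied extents 2 x 4 x 32
        assert((unmerge(256 - 1) == std::array<int, 3>{1, 3, 31}));
        return 0;
    }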
    using VGradGemmTile_N_O_M = VGradGemmTile_N_O_M_<>; // tune later

    template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
    struct YDotYGrad_M_O_
    {
        static constexpr index_t SrcScalarPerVetor = 16 / sizeof(DataType);

        static constexpr auto ThreadClusterLength_O =
            Number<BlockSliceLength_O_ / SrcScalarPerVetor>{};
        static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
        static constexpr auto ThreadSliceLength_O   = Number<SrcScalarPerVetor>{};
        static constexpr auto ThreadSliceLength_M =
            Number<BlockSliceLength_M_ * ThreadClusterLength_O / BlockSize_>{};

        static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
        static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");

        using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
                                        DataType,
                                        ThreadSliceLength_M * ThreadSliceLength_O,
                                        true>;
        using DstBufType =
            StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, ThreadSliceLength_M, true>;
    };
    using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, Gemm1NPerBlock>;
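Aside (editorial illustration, not part of the commit): a quick numeric sanity check of the YDotYGrad_M_O_ partitioning. The values BlockSize = 256, MPerBlock = 128, Gemm1NPerBlock = 64 and 2-byte (fp16) elements are assumed for illustration only and are not taken from the diff:

    int main()
    {
        constexpr int BlockSize = 256, MPerBlock = 128, OPerBlock = 64, elem_bytes = 2;

        constexpr int SrcScalarPerVetor     = 16 / elem_bytes;                   // 8-wide vector reads
        constexpr int ThreadClusterLength_O = OPerBlock / SrcScalarPerVetor;     // 8 threads along O
        constexpr int ThreadClusterLength_M = BlockSize / ThreadClusterLength_O; // 32 threads along M
        constexpr int ThreadSliceLength_O   = SrcScalarPerVetor;                 // 8 elements per thread along O
        constexpr int ThreadSliceLength_M   = MPerBlock * ThreadClusterLength_O / BlockSize; // 4 along M

        // same invariants as the two static_asserts in YDotYGrad_M_O_
        static_assert(ThreadClusterLength_O * ThreadSliceLength_O == OPerBlock, "");
        static_assert(ThreadClusterLength_M * ThreadSliceLength_M == MPerBlock, "");
        return 0;
    }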
    // QGrad Gemm

    // KGrad Gemm

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    using GridwiseGemmPipe = remove_cvref_t<decltype(
        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

    template <typename ABlockDesc_AK0_M_AK1>
    __host__ __device__ static constexpr auto
    MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
    {
        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);

        return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, MWaves, MPerXdl>(
            ABlockDesc_AK0_M_AK1{});
    }

    template <typename BBlockDesc_BK0_N_BK1>
    __host__ __device__ static constexpr auto
    MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
    {
        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);

        return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
            BBlockDesc_BK0_N_BK1{});
    }

    template <typename ABlockDesc_AK0_M_AK1>
    __host__ __device__ static constexpr auto
    MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
...
...
@@ -274,6 +145,33 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
    }

    template <typename AccThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4>
    __host__ __device__ static constexpr auto GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
        const AccThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4& acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4)
    {
        // acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1
        // n0_n1_n2_n3 -> k0
        // m0_m1_m2 -> m
        // n4 -> k1
        // NOTE: had to use merge_v3 or will spit out compilation errors
        const auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
        const auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
        const auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
        const auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
        const auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
        const auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
        const auto n3 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
        const auto n4 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);

        return transform_tensor_descriptor(
            acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
            make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
                       make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)),
                       make_pass_through_transform(n4)),
            make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
    }
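Aside (editorial illustration, not part of the commit): the merge transform above linearizes (n0, n1, n2, n3) into a single K0 coordinate and (m0, m1, m2) into M, with the rightmost index moving fastest. A standalone sketch of that index arithmetic, using made-up lengths rather than values from the kernel:

    #include <cassert>

    int main()
    {
        // made-up 8-D accumulator lengths; only the merged extents matter here
        constexpr int N0 = 2, N1 = 2, N2 = 4, N3 = 2; // (n0, n1, n2, n3) merged into K0
        constexpr int M1 = 2, M2 = 4;                 // (m0, m1, m2) merged into M

        auto k0 = [&](int a, int b, int c, int d) { return ((a * N1 + b) * N2 + c) * N3 + d; };
        auto m  = [&](int a, int b, int c) { return (a * M1 + b) * M2 + c; };

        assert(k0(0, 0, 0, 1) == 1);                                          // n3 is the fastest index
        assert(k0(N0 - 1, N1 - 1, N2 - 1, N3 - 1) == N0 * N1 * N2 * N3 - 1);  // K0 length = 32
        assert(m(0, 1, 0) == M2);                                             // stepping m1 moves by M2
        return 0;
    }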
    __host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        // B1 matrix in LDS memory, dst of blockwise copy
...
...
@@ -345,11 +243,22 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");

        const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
        const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
        const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
        const auto Gemm1N = v_grid_desc_n0_o_n1.GetLength(I1);

        // This assumption reduces implementation complexity by categorizing 6 separate GEMMs into 3
        // types of GEMM operations, therefore some code body can be reused accordingly
        // P_MNK / dP_MNO Gemm (Gemm0 rcr)
        // Y_MON / dQ_MKN Gemm (Gemm1 rrr)
        // dV_NOM / dK_NKM Gemm (Gemm2 crr)
        if(Gemm1N != K)
        {
            std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
            return false;
        }

        if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1)))
        {
            return false;
...
...
@@ -446,137 +355,329 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
    using DefaultBlock2CTileMap =
        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;

    struct SharedMemTrait
    // P / dP Gemm (type 1 rcr)
    struct Gemm0
    {
        // LDS allocation for A and B: be careful of alignment
        // A matrix in LDS memory, dst of blockwise copy
        static constexpr auto a_block_desc_ak0_m_ak1 =
            GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

        // B matrix in LDS memory, dst of blockwise copy
        static constexpr auto b_block_desc_bk0_n_bk1 =
            GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        static constexpr auto b1_block_desc_bk0_n_bk1 =

        template <typename ABlockDesc_AK0_M_AK1>
        __host__ __device__ static constexpr auto
        MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
        {
            constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);

            return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, MWaves, MPerXdl>(
                ABlockDesc_AK0_M_AK1{});
        }

        template <typename BBlockDesc_BK0_N_BK1>
        __host__ __device__ static constexpr auto
        MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
        {
            constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);

            return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
                BBlockDesc_BK0_N_BK1{});
        }

        template <typename GridDesc_K0_M_K1>
        using ABlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock, AElementwiseOperation, tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<AK0, MPerBlock, AK1>,
            ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder,
            DataType, DataType,
            GridDesc_K0_M_K1, decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder, Sequence<1, 0, 2>, ABlockTransferSrcVectorDim, 2,
            ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1,
            1, 1,
            true, // SrcResetCoord
            true, // DstResetCoord
            NumGemmKPrefetchStage>;

        template <typename GridDesc_K0_N_K1>
        using BBlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock, BElementwiseOperation, tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<BK0, NPerBlock, BK1>,
            BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder,
            DataType, DataType,
            GridDesc_K0_N_K1, decltype(b_block_desc_bk0_n_bk1),
            BBlockTransferSrcAccessOrder, Sequence<1, 0, 2>, BBlockTransferSrcVectorDim, 2,
            BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1,
            1, 1,
            true, // SrcResetCoord
            true, // DstResetCoord
            NumGemmKPrefetchStage>;

        static constexpr index_t KPack =
            math::max(math::lcm(AK1, BK1),
                      MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

        // Blockwise gemm with transposed XDL output
        using BlockwiseGemm = BlockwiseGemmXdlops_v2<
            BlockSize, DataType, FloatGemmAcc,
            decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1),
            decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
            decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
            MPerBlock, NPerBlock, KPerBlock,
            MPerXdl, NPerXdl,
            MXdlPerWave, NXdlPerWave,
            KPack,
            true>; // TransposeC

        static constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
        static constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
    };

    // Y / dQ Gemm (type 2 rrr)
    template <typename ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4,
              typename ASrcBlockDesc_M0_N0_M1_N1_M2_N2_N3_N4>
    struct Gemm1
    {
        private:
        static constexpr auto m0 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I0);
        static constexpr auto n0 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I1);
        static constexpr auto m1 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I2);
        static constexpr auto n1 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I3);
        static constexpr auto m2 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I4);
        static constexpr auto n2 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I5);
        static constexpr auto n3 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I6);
        static constexpr auto n4 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I7);

        // N2 num_groups_per_blk, N3 num_input_blks, N4 group_size
        static constexpr auto N3 = ASrcBlockDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I6);

        public:
        static constexpr auto AThreadSliceLength_K0 = Number<Gemm1KPerBlock / n4 / N3>{};
        static constexpr auto AThreadSliceLength_M  = Number<m0 * m1 * m2>{};
        static constexpr auto AThreadSliceLength_K1 = Number<n4>{};

        static constexpr auto acc_thread_desc_k0_m_k1 =
            GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
                ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{});

        static constexpr auto a_thread_desc_k0_m_k1 = make_naive_tensor_descriptor_packed(
            make_tuple(AThreadSliceLength_K0, AThreadSliceLength_M, AThreadSliceLength_K1));

        static constexpr auto b_block_desc_bk0_n_bk1 =
            GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        static constexpr auto p_block_desc_m0_n_m1 =
            VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();
        static constexpr auto ygrad_block_desc_m0_o_m1 =
            VGradGemmTile_N_O_M::GetYGradBlockDescriptor_M0_O_M1();

        static constexpr auto max_lds_align = Number<16 / sizeof(DataType)>{};

        static constexpr auto ASrcScalarPerVector = n4;

        static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
        static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
        static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
            b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
        static constexpr auto p_block_space_size_aligned = math::integer_least_multiple(
            p_block_desc_m0_n_m1.GetElementSpaceSize(), max_lds_align);
        static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
            ygrad_block_desc_m0_o_m1.GetElementSpaceSize(), max_lds_align);

        using AThreadSliceLengths_K0_M_K1 = decltype(a_thread_desc_k0_m_k1.GetLengths());

        using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
            FloatGemmAcc, DataType,
            decltype(acc_thread_desc_k0_m_k1), decltype(a_thread_desc_k0_m_k1),
            tensor_operation::element_wise::PassThrough,
            AThreadSliceLengths_K0_M_K1,
            Sequence<1, 0, 2>, 2,
            ASrcScalarPerVector>;

        template <typename GridDesc_K0_N_K1>
        using BBlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock, BElementwiseOperation, tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<B1K0, Gemm1NPerBlock, B1K1>,
            B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder,
            DataType, DataType,
            GridDesc_K0_N_K1, decltype(b_block_desc_bk0_n_bk1),
            B1BlockTransferSrcAccessOrder, Sequence<1, 0, 2>, B1BlockTransferSrcVectorDim, 2,
            B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1,
            1, 1,
            B1ThreadTransferSrcResetCoordinateAfterRun,
            true, // DstResetCoord
            NumGemmKPrefetchStage>;

        // for a_block_slice_copy_step to be able to address static buffers, it MUST be a
        // tuple-based container as well as containing ONLY integral constants
        static constexpr auto a_block_slice_copy_step = make_tuple(AThreadSliceLength_K0, I0, I0);
        static constexpr auto b_block_slice_copy_step =
            make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);

        // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
        // selected_mfma.k_per_blk <= Gemm1KPack
        //
        // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
        // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
        // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
        // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
        // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
        // therefore we may just as well assign Gemm1KPack = group_size
        static constexpr index_t GemmKPack =
            MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.group_size;

        using BlockwiseGemm = BlockwiseGemmXdlops_v2<
            BlockSize, DataType, FloatGemmAcc,
            decltype(a_thread_desc_k0_m_k1), decltype(b_block_desc_bk0_n_bk1),
            decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a_thread_desc_k0_m_k1)),
            decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
            MPerBlock, Gemm1NPerBlock, Gemm1KPerBlock,
            MPerXdl, NPerXdl,
            MXdlPerWave, Gemm1NXdlPerWave,
            GemmKPack,
            true,      // TransposeC
            GemmKPack, // AMmaKStride
            GemmKPack * XdlopsGemm<DataType, MPerXdl, NPerXdl, GemmKPack, false>{}
                            .K0PerXdlops /* BMmaKStride */>;
    };
    // dV / dK Gemm (type 3 crr)
    template <index_t Sum_M_ = MPerXdl * 2>
    struct VGradGemmTile_N_O_M_
    {
        static constexpr index_t Free0_N = NPerBlock;
        static constexpr index_t Free1_O = Gemm1NPerBlock;
        static constexpr index_t Sum_M   = Sum_M_;

        static constexpr index_t P_M1     = 8; // P will be row-major
        static constexpr index_t P_M0     = Sum_M / P_M1;
        static constexpr index_t P_LdsPad = 0; // how many multiples of M1 per N * M1 elements

        static constexpr index_t YGrad_M1     = 2; // dY assumed row-major, typically =2 for fp16
        static constexpr index_t YGrad_M0     = Sum_M / YGrad_M1;
        static constexpr index_t YGrad_LdsPad = 0; // how many multiples of M1 per N * M1 elements

        static_assert(Sum_M % MPerXdl == 0, "");

        static constexpr index_t YGrad_SrcVectorDim       = 1; // Free1_O dimension
        static constexpr index_t YGrad_SrcScalarPerVector = 4;

        static constexpr index_t GemmNWave   = 2;
        static constexpr index_t GemmOWave   = BlockSize / get_warp_size() / GemmNWave;
        static constexpr index_t GemmNRepeat = Free0_N / GemmNWave / MPerXdl;
        static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl;
        static constexpr index_t GemmMPack =
            math::max(math::lcm(P_M1, YGrad_M1),
                      MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

        using YGrad_BlockSliceLengths = Sequence<YGrad_M0, Free1_O, YGrad_M1>;
        using YGrad_ThreadClusterLengths =
            Sequence<BlockSize / (Free1_O / YGrad_SrcScalarPerVector),
                     Free1_O / YGrad_SrcScalarPerVector,
                     1>;
        using YGrad_ThreadClusterArrangeOrder = Sequence<0, 2, 1>;

        static constexpr auto a_block_space_offset     = 0;
        static constexpr auto b_block_space_offset     = a_block_space_size_aligned.value;
        static constexpr auto b1_block_space_offset    = 0;
        static constexpr auto p_block_space_offset     = 0;
        static constexpr auto ygrad_block_space_offset = p_block_space_size_aligned.value;

        __host__ __device__ static constexpr auto GetPBlockDescriptor_M0_N_M1()
        {
            constexpr index_t P_M0 = Sum_M / P_M1;
            return make_naive_tensor_descriptor(
                make_tuple(Number<P_M0>{}, Number<Free0_N>{}, Number<P_M1>{}),
                make_tuple(Number<Free0_N + P_LdsPad>{} * Number<P_M1>{}, Number<P_M1>{}, I1));
        }

        __host__ __device__ static constexpr auto GetYGradBlockDescriptor_M0_O_M1()
        {
            constexpr index_t YGrad_M0 = Sum_M / YGrad_M1;
            return make_naive_tensor_descriptor(
                make_tuple(Number<YGrad_M0>{}, Number<Free1_O>{}, Number<YGrad_M1>{}),
                make_tuple(Number<Free1_O + YGrad_LdsPad>{} * Number<YGrad_M1>{},
                           Number<YGrad_M1>{},
                           I1));
        }

        // LDS allocation for reduction
        static constexpr index_t reduction_space_size_aligned =
            math::integer_least_multiple(BlockSize, max_lds_align);

        __host__ __device__ static constexpr auto GetPBlockSliceLengths_M0_N0_M1_N1_M2_N2()
        {
            // perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl
            constexpr index_t m  = Sum_M - 1;
            constexpr index_t m2 = m % MPerXdl;
            constexpr index_t m1 = m / MPerXdl % Gemm0MWaves;
            constexpr index_t m0 = m / MPerXdl / Gemm0MWaves % MXdlPerWave;

            static constexpr auto reduction_space_offset = 0;

            // perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl
            constexpr index_t n  = Free0_N - 1;
            constexpr index_t n2 = n % NPerXdl;
            constexpr index_t n1 = n / NPerXdl % Gemm0NWaves;
            constexpr index_t n0 = n / NPerXdl / Gemm0NWaves % NXdlPerWave;

            // LDS allocation for C shuffle in LDS
            static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
            static constexpr auto c_block_space_size =
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
        };

            // assume 256 decomposed into 2 x 4 x 32
            // 1d idx ( 32 - 1) -> 3d idx 0, 0, 31 -> 3d dim 1 x 1 x 32
            // 1d idx (256 - 1) -> 3d idx 1, 3, 31 -> 3d dim 2 x 4 x 32
            return Sequence<m0, n0, m1, n1, m2, n2>{} + Sequence<1, 1, 1, 1, 1, 1>{};
        }
    // P / dP Gemm (type 1 rcr)
    struct Gemm0
    {
        private:
        static constexpr auto a_block_desc = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        static constexpr auto b_block_desc = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        __host__ __device__ static constexpr auto GetPBlockSliceLengths_M0_N0_M1_N1()
        {
            return generate_sequence_v2(
                [](auto I) { return GetPBlockSliceLengths_M0_N0_M1_N1_M2_N2().At(I); },
                Number<4>{});
        }
    };

        public:
        template <typename GridDesc_K0_M_K1>
        using ABlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock, AElementwiseOperation, tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<AK0, MPerBlock, AK1>,
            ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder,
            DataType, DataType,
            GridDesc_K0_M_K1, decltype(a_block_desc),
            ABlockTransferSrcAccessOrder, Sequence<1, 0, 2>, ABlockTransferSrcVectorDim, 2,
            ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1,
            1, 1,
            true, // SrcResetCoord
            true, // DstResetCoord
            NumGemmKPrefetchStage>;

    using VGradGemmTile_N_O_M = VGradGemmTile_N_O_M_<>; // tune later

        template <typename GridDesc_K0_N_K1>
        using BBlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock, BElementwiseOperation, tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<BK0, NPerBlock, BK1>,
            BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder,
            DataType, DataType,
            GridDesc_K0_N_K1, decltype(b_block_desc),
            BBlockTransferSrcAccessOrder, Sequence<1, 0, 2>, BBlockTransferSrcVectorDim, 2,
            BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1,
            1, 1,
            true, // SrcResetCoord
            true, // DstResetCoord
            NumGemmKPrefetchStage>;

    template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
    struct YDotYGrad_M_O_
    {
        static constexpr index_t SrcScalarPerVetor = 16 / sizeof(DataType);

        static constexpr auto ThreadClusterLength_O =
            Number<BlockSliceLength_O_ / SrcScalarPerVetor>{};
        static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
        static constexpr auto ThreadSliceLength_O   = Number<SrcScalarPerVetor>{};
        static constexpr auto ThreadSliceLength_M =
            Number<BlockSliceLength_M_ * ThreadClusterLength_O / BlockSize_>{};

        static constexpr index_t KPack =
            math::max(math::lcm(AK1, BK1),
                      MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

        static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
        static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");

        // Blockwise gemm with transposed XDL output
        using BlockwiseGemm = BlockwiseGemmXdlops_v2<
            BlockSize, DataType, FloatGemmAcc,
            decltype(a_block_desc), decltype(b_block_desc),
            decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc)),
            decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc)),
            MPerBlock, NPerBlock, KPerBlock,
            MPerXdl, NPerXdl,
            MXdlPerWave, NXdlPerWave,
            KPack,
            true>; // TransposeC

        using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
                                        DataType,
                                        ThreadSliceLength_M * ThreadSliceLength_O,
                                        true>;

        static constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
        static constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);

        using DstBufType =
            StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, ThreadSliceLength_M, true>;
    };
    using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, Gemm1NPerBlock>;

    // PGrad Gemm has the same layout as P = Q * K^T Gemm (A row-major B col-major)
    struct PGradGemmTile_M_N_O
...
...
@@ -631,13 +732,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
    struct QGradGemmTile_M_K_N
    {
        template <typename QGridDesc_K0_M_K1_>
        __device__ static const auto MakeQGradGridDesc_MBlock_MPerBlock_KBlock_KPerBlock(
            const QGridDesc_K0_M_K1_& q_grid_desc_k0_m_k1)
        {
            const auto K0 = q_grid_desc_k0_m_k1.GetLength(I0);
            const auto M  = q_grid_desc_k0_m_k1.GetLength(I1);
            const auto K1 = q_grid_desc_k0_m_k1.GetLength(I2);
            const auto K  = K0 * K1;

            const auto MBlock = M / MPerBlock;
            const auto KBlock = K / Gemm1NPerBlock;

            // NOTE: QGrad gemm is similar to Y gemm
...
...
@@ -659,7 +760,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        template <typename SGradThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4_>
        __device__ static const auto MakeSGradThreadDesc_N0_M_N1(
            const SGradThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4_& sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4)
        {
            constexpr auto m0 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
            constexpr auto n0 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
...
...
@@ -673,8 +775,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            constexpr auto sgrad_thread_desc_n0_m_n1 = transform_tensor_descriptor(
                sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
                           make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)),
                           make_pass_through_transform(n4)),
                make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
...
...
@@ -703,6 +805,52 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        }
    };

    struct SharedMemTrait
    {
        // LDS allocation for A and B: be careful of alignment
        static constexpr auto a_block_desc_ak0_m_ak1 =
            GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        static constexpr auto b_block_desc_bk0_n_bk1 =
            GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
        static constexpr auto b1_block_desc_bk0_n_bk1 =
            GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
        static constexpr auto p_block_desc_m0_n_m1 =
            VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();
        static constexpr auto ygrad_block_desc_m0_o_m1 =
            VGradGemmTile_N_O_M::GetYGradBlockDescriptor_M0_O_M1();

        static constexpr auto max_lds_align = Number<16 / sizeof(DataType)>{};

        static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
        static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
        static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
            b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
        static constexpr auto p_block_space_size_aligned = math::integer_least_multiple(
            p_block_desc_m0_n_m1.GetElementSpaceSize(), max_lds_align);
        static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
            ygrad_block_desc_m0_o_m1.GetElementSpaceSize(), max_lds_align);

        static constexpr auto a_block_space_offset     = 0;
        static constexpr auto b_block_space_offset     = a_block_space_size_aligned.value;
        static constexpr auto b1_block_space_offset    = 0;
        static constexpr auto p_block_space_offset     = 0;
        static constexpr auto ygrad_block_space_offset = p_block_space_size_aligned.value;

        // LDS allocation for reduction
        static constexpr index_t reduction_space_size_aligned =
            math::integer_least_multiple(BlockSize, max_lds_align);
        static constexpr auto reduction_space_offset = 0;

        // LDS allocation for C shuffle in LDS
        static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
        static constexpr auto c_block_space_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
    };
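Aside (editorial illustration, not part of the commit): SharedMemTrait carves one dynamic LDS allocation by rounding each tile's element count up to an alignment and laying regions back to back; the b1/p/ygrad offsets restart at 0 (or at p's aligned size), which together with the "reuse LDS space for gemm0's b_block_buf" note further down suggests those buffers reuse the Gemm0 A/B region once it is no longer live. A standalone sketch with made-up sizes and a local rounding helper (not the ck implementation):

    #include <cassert>

    constexpr int integer_least_multiple(int x, int align) { return (x + align - 1) / align * align; }

    int main()
    {
        constexpr int max_lds_align = 8;         // e.g. 16 bytes / sizeof(fp16)
        constexpr int a_size = 130;              // made-up element count

        constexpr int a_size_aligned = integer_least_multiple(a_size, max_lds_align); // 136
        constexpr int a_offset       = 0;
        constexpr int b_offset       = a_size_aligned; // B tile starts right after the aligned A tile
        constexpr int b1_offset      = 0;              // B1 later reuses the A/B region

        static_assert(b_offset % max_lds_align == 0, "");
        assert(a_offset + a_size <= b_offset);
        (void)b1_offset;
        return 0;
    }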
    template <bool HasMainKBlockLoop,
              typename Block2CTileMap,
              typename C0MatrixMask,
...
...
@@ -774,19 +922,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        // set up P / dP Gemm (type 1 rcr)
        //
        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // A matrix blockwise copy
        auto a_blockwise_copy =
            typename Gemm0::template ABlockwiseCopy<decltype(q_grid_desc_k0_m_k1)>(
                q_grid_desc_k0_m_k1,
                make_multi_index(0, m_block_data_idx_on_grid, 0),
                a_element_op,
                a_block_desc_ak0_m_ak1,
                Gemm0::a_block_desc_ak0_m_ak1,
                make_multi_index(0, 0, 0),
                tensor_operation::element_wise::PassThrough{});
...
...
@@ -796,7 +938,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                k_grid_desc_k0_n_k1,
                make_multi_index(0, 0, 0), // will loop over GemmN dimension
                b_element_op,
                b_block_desc_bk0_n_bk1,
                Gemm0::b_block_desc_bk0_n_bk1,
                make_multi_index(0, 0, 0),
                tensor_operation::element_wise::PassThrough{});
...
...
@@ -807,14 +949,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        // LDS allocation for A and B: be careful of alignment
        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<DataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
            a_block_desc_ak0_m_ak1.GetElementSpaceSize());
            Gemm0::a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<DataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = Gemm0::a_block_slice_copy_step;
        constexpr auto b_block_slice_copy_step = Gemm0::b_block_slice_copy_step;
            Gemm0::b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        const auto a_block_reset_copy_step =
            make_multi_index(-q_grid_desc_k0_m_k1.GetLength(I0), 0, 0);
...
...
@@ -828,95 +967,31 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                                     LoopScheduler::Default>();

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2)) / KPerBlock);

        //
        // set up Y / dQ Gemm (type 2 rrr)
        //
        using Gemm1 =
            Gemm1<decltype(s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()),
                  decltype(s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4())>;

        // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
        constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
            s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

        constexpr auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
        constexpr auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
        constexpr auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
        constexpr auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
        constexpr auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
        constexpr auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
        constexpr auto n3 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
        constexpr auto n4 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);

        constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);

        constexpr auto acc_thread_desc_k0_m_k1 = Gemm1::acc_thread_desc_k0_m_k1;

        // acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1
        // n0_n1_n2_n3 -> k0
        // m0_m1_m2 -> m
        // n4 -> k1
        // NOTE: had to use merge_v3 or will spit out compilation errors
        constexpr auto acc_thread_desc_k0_m_k1 = transform_tensor_descriptor(
            acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
            make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
                       make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)),
                       make_pass_through_transform(n4)),
            make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        // A1 matrix in AccVGPR
        // N2 num_groups_per_blk, N3 num_input_blks, N4 group_size
        constexpr auto AccN3 =
            s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6);

        constexpr auto A1ThreadSlice_K0_M_K1 = make_tuple(
            Number<Gemm1KPerBlock / n4 / AccN3>{}, Number<m0 * m1 * m2>{}, Number<n4>{});

        constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0];
        constexpr auto A1ThreadSliceM  = A1ThreadSlice_K0_M_K1[I1];
        constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2];
        constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor(
            A1ThreadSlice_K0_M_K1,
            make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1));

        // A1 matrix in accumulator VGPR, dst of blockwise copy
        constexpr auto a1_thread_desc_k0_m_k1 = Gemm1::a_thread_desc_k0_m_k1;

        // B1 matrix in LDS memory, dst of blockwise copy
        constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
        constexpr auto b1_block_desc_bk0_n_bk1 = Gemm1::b_block_desc_bk0_n_bk1;

        // A1 matrix blockwise copy
        auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
            FloatGemmAcc, DataType,
            decltype(acc_thread_desc_k0_m_k1), decltype(a1_thread_desc_k0_m_k1),
            tensor_operation::element_wise::PassThrough,
            Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
            Sequence<1, 0, 2>, 2,
            n4>{tensor_operation::element_wise::PassThrough{}};
        auto a1_blockwise_copy =
            typename Gemm1::ABlockwiseCopy{tensor_operation::element_wise::PassThrough{}};

        // B1 matrix blockwise copy
        auto b1_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock, BElementwiseOperation, tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<B1K0, Gemm1NPerBlock, B1K1>,
            B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder,
            DataType, DataType,
            decltype(v_grid_desc_n0_o_n1), decltype(b1_block_desc_bk0_n_bk1),
            B1BlockTransferSrcAccessOrder, Sequence<1, 0, 2>, B1BlockTransferSrcVectorDim, 2,
            B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1,
            1, 1,
            B1ThreadTransferSrcResetCoordinateAfterRun,
            true, // DstResetCoord
            NumGemmKPrefetchStage>(
            typename Gemm1::template BBlockwiseCopy<decltype(v_grid_desc_n0_o_n1)>(
                v_grid_desc_n0_o_n1,
                make_multi_index(0, o_block_data_idx_on_grid, 0),
                b1_element_op,
...
...
@@ -927,44 +1002,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
            a1_thread_desc_k0_m_k1.GetElementSpaceSize());

        // reuse LDS space for gemm0's b_block_buf
        auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
            b1_block_desc_bk0_n_bk1.GetElementSpaceSize());

        // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
        // selected_mfma.k_per_blk <= Gemm1KPack
        //
        // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
        // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
        // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
        // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
        // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
        // therefore we may just as well assign Gemm1KPack = group_size
        constexpr index_t Gemm1KPack =
            MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
        constexpr index_t Gemm1KPack = Gemm1::GemmKPack;

        auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
            BlockSize, DataType, FloatGemmAcc,
            decltype(a1_thread_desc_k0_m_k1), decltype(b1_block_desc_bk0_n_bk1),
            decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a1_thread_desc_k0_m_k1)),
            decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b1_block_desc_bk0_n_bk1)),
            MPerBlock, Gemm1NPerBlock, Gemm1KPerBlock,
            MPerXdl, NPerXdl,
            MXdlPerWave, Gemm1NXdlPerWave,
            Gemm1KPack,
            true,       // TransposeC
            Gemm1KPack, // AMmaKStride
            Gemm1KPack * XdlopsGemm<DataType, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>
            { // BMmaKStride
                make_tuple(0, 0, 0, 0)}; // A_origin
        auto gemm1_blockwise_gemm =
            typename Gemm1::BlockwiseGemm{make_tuple(0, 0, 0, 0)}; // A_origin

        auto acc1_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
...
...
@@ -1003,6 +1048,17 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            chain_tensor_adaptors(m0_n_m1_to_m_n_adaptor, threadid_to_m0_n_m1_adaptor);

        // get acc0 2D thread cluster & 2D thread slice
        constexpr auto thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
            s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

        constexpr auto m0 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
        constexpr auto n0 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
        constexpr auto m1 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
        constexpr auto n1 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
        constexpr auto m2 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
        constexpr auto n2 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
        constexpr auto n3 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
        constexpr auto n4 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);

        constexpr auto thread_cluster_desc_m_n = make_naive_tensor_descriptor_packed(
            make_tuple(tm0 * tm1 * tm2, tn0 * tn1 * tn2 * tn3 * tn4));
        constexpr auto thread_slice_desc_m_n =
...
...
@@ -1422,7 +1478,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                ygrad_grid_desc_o0_m_o1,
                make_multi_index(0, m_block_data_idx_on_grid, 0),
                tensor_operation::element_wise::PassThrough{},
                a_block_desc_ak0_m_ak1,
                Gemm0::a_block_desc_ak0_m_ak1,
                make_multi_index(0, 0, 0),
                tensor_operation::element_wise::PassThrough{});
...
...
@@ -1432,7 +1488,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                v_grid_desc_o0_n_o1,
                make_multi_index(0, 0, 0), // will loop over GemmN dimension
                tensor_operation::element_wise::PassThrough{},
                b_block_desc_bk0_n_bk1,
                Gemm0::b_block_desc_bk0_n_bk1,
                make_multi_index(0, 0, 0),
                tensor_operation::element_wise::PassThrough{});
...
...
@@ -1461,6 +1517,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        //
        // set up dQ Gemm (type 2 rrr)
        //
        // transform input and output tensor descriptors
        const auto k_grid_desc_n0_k_n1 =
            QGradGemmTile_M_K_N::MakeKGridDesc_N0_K_N1(k_grid_desc_k0_n_k1);

        auto qgrad_grid_desc_mblock_mperblock_kblock_kperblock =
...
@@ -1468,41 +1526,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                q_grid_desc_k0_m_k1);

        // dQ Gemm A matrix blockwise copy
        auto qgrad_gemm_tile_sgrad_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
            FloatGemmAcc, DataType,
            decltype(acc_thread_desc_k0_m_k1), // reuse desc
            decltype(a1_thread_desc_k0_m_k1),  // reuse desc
            tensor_operation::element_wise::PassThrough,
            Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
            Sequence<1, 0, 2>, 2,
            n4>{tensor_operation::element_wise::PassThrough{}};
        auto qgrad_gemm_tile_sgrad_blockwise_copy =
            typename Gemm1::ABlockwiseCopy{tensor_operation::element_wise::PassThrough{}};

        // dQ Gemm B matrix blockwise copy
        auto qgrad_gemm_tile_k_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock, BElementwiseOperation, tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<B1K0, Gemm1NPerBlock, B1K1>,          // reuse from V
            B1BlockTransferThreadClusterLengths_BK0_N_BK1, // reuse from V
            B1BlockTransferThreadClusterArrangeOrder,      // reuse from V
            DataType, DataType,
            decltype(k_grid_desc_n0_k_n1),
            decltype(b1_block_desc_bk0_n_bk1),             // reuse from V
            B1BlockTransferSrcAccessOrder, Sequence<1, 0, 2>, B1BlockTransferSrcVectorDim, 2,
            B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1,
            1, 1,
            B1ThreadTransferSrcResetCoordinateAfterRun,
            true, // DstResetCoord
            NumGemmKPrefetchStage>(
            typename Gemm1::template BBlockwiseCopy<decltype(k_grid_desc_n0_k_n1)>(
                k_grid_desc_n0_k_n1,
                make_multi_index(0, o_block_data_idx_on_grid, 0),
                b1_element_op,
...
...
@@ -1510,32 +1539,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                make_multi_index(0, 0, 0),
                tensor_operation::element_wise::PassThrough{});

        auto qgrad_blockwise_gemm = BlockwiseGemmXdlops_v2<
            BlockSize, DataType, FloatGemmAcc,
            decltype(a1_thread_desc_k0_m_k1), decltype(b1_block_desc_bk0_n_bk1),
            decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a1_thread_desc_k0_m_k1)),
            decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b1_block_desc_bk0_n_bk1)),
            MPerBlock, Gemm1NPerBlock, Gemm1KPerBlock,
            MPerXdl, NPerXdl,
            MXdlPerWave, Gemm1NXdlPerWave,
            Gemm1KPack,
            true,       // TransposeC
            Gemm1KPack, // AMmaKStride
            Gemm1KPack * XdlopsGemm<DataType, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>
            { // BMmaKStride
                make_tuple(0, 0, 0, 0)}; // A_origin
        auto qgrad_blockwise_gemm =
            typename Gemm1::BlockwiseGemm{make_tuple(0, 0, 0, 0)}; // A_origin

        auto qgrad_thread_buf = qgrad_blockwise_gemm.GetCThreadBuffer();

        //
        // calculate y dot ygrad
        // calculate Y dot dY
        //
        // clear accum buffers
...
...
@@ -1632,8 +1642,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                make_tuple(I0, I0, I0, I0),
                                lse_thread_buf);

        const index_t num_gemm1_k_block_outer_loop = k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
        constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;

        // Initialize dQ
...
...
@@ -1652,17 +1661,17 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            }

            // P = Q * K^T
            gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
                q_grid_desc_k0_m_k1,
                a_block_desc_ak0_m_ak1,
                Gemm0::a_block_desc_ak0_m_ak1,
                a_blockwise_copy,
                q_grid_buf,
                a_block_buf,
                a_block_slice_copy_step,
                Gemm0::a_block_slice_copy_step,
                k_grid_desc_k0_n_k1,
                b_block_desc_bk0_n_bk1,
                Gemm0::b_block_desc_bk0_n_bk1,
                b_blockwise_copy,
                k_grid_buf,
                b_block_buf,
                b_block_slice_copy_step,
                Gemm0::b_block_slice_copy_step,
                s_blockwise_gemm,
                s_slash_p_thread_buf,
                num_k_block_main_loop);
...
...
@@ -1857,17 +1866,17 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            block_sync_lds();

            gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
                ygrad_grid_desc_o0_m_o1,
                a_block_desc_ak0_m_ak1,        // reuse
                Gemm0::a_block_desc_ak0_m_ak1, // reuse
                pgrad_gemm_tile_ygrad_blockwise_copy,
                ygrad_grid_buf,
                a_block_buf,                    // reuse
                a_block_slice_copy_step,        // reuse
                Gemm0::a_block_slice_copy_step, // reuse
                v_grid_desc_o0_n_o1,
                b_block_desc_bk0_n_bk1,        // reuse
                Gemm0::b_block_desc_bk0_n_bk1, // reuse
                pgrad_gemm_tile_v_blockwise_copy,
                v_grid_buf,
                b_block_buf,                    // reuse
                b_block_slice_copy_step,        // reuse
                Gemm0::b_block_slice_copy_step, // reuse
                pgrad_blockwise_gemm,
                pgrad_thread_buf,
                num_o_block_main_loop);
...
...
@@ -1897,8 +1906,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                constexpr auto n =
                    pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I1];
                // dS and P have the same thread buf layout
                sgrad_thread_buf(i) =
                    s_slash_p_thread_buf[i] *
                    (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]);
            });
#if 0
...
...
@@ -1927,7 +1936,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                qgrad_gemm_tile_k_blockwise_copy.RunRead(k_grid_desc_n0_k_n1, k_grid_buf);
                qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(k_grid_desc_n0_k_n1,
                                                                    b1_block_slice_copy_step);
                                                                    Gemm1::b_block_slice_copy_step);

                block_sync_lds(); // wait for previous LDS read
...
...
@@ -1944,13 +1953,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                if constexpr(num_gemm1_k_block_inner_loop > 1)
                {
                    static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
                        qgrad_gemm_tile_sgrad_blockwise_copy.Run(
                            acc_thread_desc_k0_m_k1,
                            make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0),
                            sgrad_thread_buf,
                            a1_thread_desc_k0_m_k1,
                            make_tuple(I0, I0, I0),
                            a1_thread_buf);
                        qgrad_gemm_tile_sgrad_blockwise_copy.Run(
                            acc_thread_desc_k0_m_k1,
                            Gemm1::a_block_slice_copy_step * i,
                            sgrad_thread_buf,
                            a1_thread_desc_k0_m_k1,
                            make_tuple(I0, I0, I0),
                            a1_thread_buf);
#if 0
if(hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
...
...
@@ -1971,18 +1979,18 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                        block_sync_lds();

                        qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
                            k_grid_desc_n0_k_n1, b1_block_slice_copy_step);
                        qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
                            k_grid_desc_n0_k_n1, Gemm1::b_block_slice_copy_step);

                        qgrad_gemm_tile_k_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1,
                                                                  b1_block_buf);
                    });
                }

                // tail
                {
                    qgrad_gemm_tile_sgrad_blockwise_copy.Run(
                        acc_thread_desc_k0_m_k1,
                        make_tuple(Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0),
                        Gemm1::a_block_slice_copy_step * Number<num_gemm1_k_block_inner_loop - 1>{},
                        sgrad_thread_buf,
                        a1_thread_desc_k0_m_k1,
                        make_tuple(I0, I0, I0),
...
...
@@ -2011,8 +2019,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop

            // TODO ANT:
            // shuffle dQ and write
#if 0
if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
...
...
include/ck/utility/statically_indexed_array_multi_index.hpp
...
...
@@ -100,6 +100,17 @@ __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
    return r;
}

template <typename... Xs, index_t N>
__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Number<N>& y)
{
    constexpr index_t NSize = sizeof...(Xs);

    // Tuple<Xs...> r;
    // static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] * y; });
    // return r;
    return generate_tuple([&](auto i) { return x[i] * y; }, Number<NSize>{});
}
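Aside (editorial illustration, not part of the commit): building the product with generate_tuple lets each result element take the type of `x[i] * y`, so a tuple of compile-time Number<> constants multiplied by a Number<> presumably stays a tuple of compile-time constants — which is what expressions like `Gemm1::a_block_slice_copy_step * i` in the gridwise kernel appear to rely on for addressing static buffers. A standalone analogue using std::integral_constant (not the ck types):

    #include <tuple>
    #include <type_traits>

    // Element-wise product of a tuple of integral constants and an integral constant;
    // the result is again made entirely of integral constants.
    template <int V> using ic = std::integral_constant<int, V>;

    template <int... Xs, int Y>
    constexpr auto operator*(std::tuple<ic<Xs>...>, ic<Y>)
    {
        return std::tuple<ic<Xs * Y>...>{};
    }

    int main()
    {
        constexpr auto step  = std::make_tuple(ic<4>{}, ic<0>{}, ic<0>{}); // like a slice-copy step
        constexpr auto twice = step * ic<2>{};
        static_assert(std::is_same_v<std::tuple_element_t<0, decltype(twice)>, ic<8>>, "");
        return 0;
    }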
// MultiIndex = scalar * MultiIndex
template <typename... Xs,
          typename Y,
...
...