Project: gaoqiong/composable_kernel

Commit 2d55c14c, authored Dec 21, 2022 by Anthony Chang

    refactor Gemm2

Parent: 383211ef
Showing 1 changed file with 357 additions and 349 deletions.
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp (+357, -349)
@@ -196,38 +196,30 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
     }
 
-    __host__ __device__ static constexpr auto GetPBlockDescriptor_NBlock_NPerBlock_MBlock_MPerBlock()
+    template <typename Gemm2Param>
+    __host__ __device__ static constexpr auto GetA2BlockDescriptor_M0_N_M1()
     {
-        constexpr auto ptrans_block_desc = make_naive_tensor_descriptor_packed(make_tuple(
-            I1, Number<VGradGemmTile_N_O_M::Free0_N>{}, I1, Number<VGradGemmTile_N_O_M::Sum_M>{}));
-
-        return ptrans_block_desc;
+        return make_naive_tensor_descriptor(
+            make_tuple(Number<Gemm2Param::A_M0>{}, Number<Gemm2Param::Free0_N>{}, Number<Gemm2Param::A_M1>{}),
+            make_tuple(Number<Gemm2Param::Free0_N + Gemm2Param::A_LdsPad>{} * Number<Gemm2Param::A_M1>{},
+                       Number<Gemm2Param::A_M1>{},
+                       I1));
     }
 
-    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
+    template <typename Gemm2Param>
+    __host__ __device__ static constexpr auto GetB2BlockDescriptor_M0_O_M1()
     {
-        const index_t gemm0_bytes_end =
-            (SharedMemTrait::a_block_space_size_aligned + SharedMemTrait::b_block_space_size_aligned) *
-            sizeof(DataType);
-        const index_t gemm1_bytes_end =
-            (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
-            sizeof(DataType);
-        const index_t vgrad_gemm_bytes_end =
-            (SharedMemTrait::p_block_space_size_aligned + SharedMemTrait::ygrad_block_space_size_aligned) *
-            sizeof(DataType);
-        const index_t softmax_bytes_end =
-            (SharedMemTrait::reduction_space_offset + SharedMemTrait::reduction_space_size_aligned) *
-            sizeof(FloatGemmAcc);
-        const index_t c_block_bytes_end = SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
-
-        return math::max(
-            gemm0_bytes_end, gemm1_bytes_end, vgrad_gemm_bytes_end, softmax_bytes_end, c_block_bytes_end);
+        return make_naive_tensor_descriptor(
+            make_tuple(Number<Gemm2Param::B_M0>{}, Number<Gemm2Param::Free1_O>{}, Number<Gemm2Param::B_M1>{}),
+            make_tuple(Number<Gemm2Param::Free1_O + Gemm2Param::B_LdsPad>{} * Number<Gemm2Param::B_M1>{},
+                       Number<Gemm2Param::B_M1>{},
+                       I1));
     }
 
     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
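For orientation, a standalone host-side sketch of the address arithmetic encoded by the strided (M0, N, M1) descriptors above; Free0_N, A_M1 and A_LdsPad here are illustrative stand-ins for the Gemm2Param members, not values taken from the kernel.

// Sketch only: how the (M0, N, M1) LDS descriptor built by GetA2BlockDescriptor_M0_N_M1
// maps a 3-d coordinate to a linear offset; the real kernel uses make_naive_tensor_descriptor.
#include <cassert>
#include <cstddef>

template <int Free0_N, int A_M1, int A_LdsPad>
constexpr std::size_t a2_lds_offset(int m0, int n, int m1)
{
    // strides from the descriptor: {(Free0_N + A_LdsPad) * A_M1, A_M1, 1}
    return static_cast<std::size_t>(m0) * (Free0_N + A_LdsPad) * A_M1 +
           static_cast<std::size_t>(n) * A_M1 + static_cast<std::size_t>(m1);
}

int main()
{
    // assumed example: Sum_M = 64, A_M1 = 8 (so A_M0 = 8), Free0_N = 128
    static_assert(a2_lds_offset<128, 8, 0>(0, 0, 0) == 0, "origin");
    static_assert(a2_lds_offset<128, 8, 0>(0, 1, 0) == 8, "one step along N skips A_M1 elements");
    static_assert(a2_lds_offset<128, 8, 0>(1, 0, 0) == 128 * 8, "one step along M0 skips a full N*M1 row");
    assert(a2_lds_offset<128, 8, 1>(1, 0, 0) == 129 * 8); // LdsPad widens the M0 stride
    return 0;
}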
@@ -485,11 +477,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         static constexpr auto AThreadSliceLength_M  = Number<m0 * m1 * m2>{};
         static constexpr auto AThreadSliceLength_K1 = Number<n4>{};
 
+        // A source matrix layout in AccVGPR
         static constexpr auto a_src_thread_desc_k0_m_k1 =
             GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{});
 
+        // A matrix in VGPR memory, dst of AccVGPR-to-VGPR copy
         static constexpr auto a_thread_desc_k0_m_k1 = make_naive_tensor_descriptor_packed(
             make_tuple(AThreadSliceLength_K0, AThreadSliceLength_M, AThreadSliceLength_K1));
 
+        // B matrix in LDS memory, dst of blockwise copy
         static constexpr auto b_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
@@ -574,68 +571,50 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     };
 
     // dV / dK Gemm (type 3 crr)
-    // TODO ANT: refactor into Gemm2
+    // Describes tuning parameter for C2_n_o = A2_n_m * B2_m_o
     template <index_t Sum_M_ = MPerXdl * 2>
-    struct VGradGemmTile_N_O_M_
+    struct Gemm2Params_N_O_M_
     {
         static constexpr index_t Free0_N = NPerBlock;
         static constexpr index_t Free1_O = Gemm1NPerBlock;
         static constexpr index_t Sum_M   = Sum_M_;
 
-        static constexpr index_t P_M1     = 8; // P will be row-major
-        static constexpr index_t P_M0     = Sum_M / P_M1;
-        static constexpr index_t P_LdsPad = 0; // how many multiples of M1 per N * M1 elements
+        static constexpr index_t A_M1     = 8; // P will be row-major
+        static constexpr index_t A_M0     = Sum_M / A_M1;
+        static constexpr index_t A_LdsPad = 0; // how many multiples of M1 per N * M1 elements
 
-        static constexpr index_t YGrad_M1     = 2; // dY assumed row-major, typically =2 for fp16
-        static constexpr index_t YGrad_M0     = Sum_M / YGrad_M1;
-        static constexpr index_t YGrad_LdsPad = 0; // how many multiples of M1 per N * M1 elements
+        static constexpr index_t B_M1     = 2; // dY assumed row-major, typically =2 for fp16
+        static constexpr index_t B_M0     = Sum_M / B_M1;
+        static constexpr index_t B_LdsPad = 0; // how many multiples of M1 per N * M1 elements
 
         static_assert(Sum_M % MPerXdl == 0, "");
 
-        static constexpr index_t YGrad_SrcVectorDim       = 1; // Free1_O dimension
-        static constexpr index_t YGrad_SrcScalarPerVector = 4;
+        static constexpr index_t BSrcVectorDim       = 1; // Free1_O dimension
+        static constexpr index_t BSrcScalarPerVector = 4;
 
         static constexpr index_t GemmNWave   = 2;
         static constexpr index_t GemmOWave   = BlockSize / get_warp_size() / GemmNWave;
         static constexpr index_t GemmNRepeat = Free0_N / GemmNWave / MPerXdl;
         static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl;
         static constexpr index_t GemmMPack =
-            math::max(math::lcm(P_M1, YGrad_M1),
+            math::max(math::lcm(A_M1, B_M1),
                       MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
 
-        using YGrad_BlockSliceLengths = Sequence<YGrad_M0, Free1_O, YGrad_M1>;
-        using YGrad_ThreadClusterLengths =
-            Sequence<BlockSize / (Free1_O / YGrad_SrcScalarPerVector),
-                     Free1_O / YGrad_SrcScalarPerVector,
-                     1>;
-        using YGrad_ThreadClusterArrangeOrder = Sequence<0, 2, 1>;
+        using BBlockSliceLengths = Sequence<B_M0, Free1_O, B_M1>;
+        using BThreadClusterLengths =
+            Sequence<BlockSize / (Free1_O / BSrcScalarPerVector), Free1_O / BSrcScalarPerVector, 1>;
+        using BThreadClusterArrangeOrder = Sequence<0, 2, 1>;
 
-        __host__ __device__ static constexpr auto GetPBlockDescriptor_M0_N_M1()
-        {
-            constexpr index_t P_M0 = Sum_M / P_M1;
-            return make_naive_tensor_descriptor(
-                make_tuple(Number<P_M0>{}, Number<Free0_N>{}, Number<P_M1>{}),
-                make_tuple(Number<Free0_N + P_LdsPad>{} * Number<P_M1>{}, Number<P_M1>{}, I1));
-        }
-
-        __host__ __device__ static constexpr auto GetYGradBlockDescriptor_M0_O_M1()
-        {
-            constexpr index_t YGrad_M0 = Sum_M / YGrad_M1;
-            return make_naive_tensor_descriptor(
-                make_tuple(Number<YGrad_M0>{}, Number<Free1_O>{}, Number<YGrad_M1>{}),
-                make_tuple(Number<Free1_O + YGrad_LdsPad>{} * Number<YGrad_M1>{}, Number<YGrad_M1>{}, I1));
-        }
-
-        __host__ __device__ static constexpr auto GetPBlockSliceLengths_M0_N0_M1_N1_M2_N2()
+        __host__ __device__ static constexpr auto GetABlockSliceLengths_M0_N0_M1_N1_M2_N2()
         {
             // perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl
-            constexpr index_t m  = Sum_M - 1;
+            constexpr index_t m  = Gemm2Params_N_O_M::Sum_M - 1;
             constexpr index_t m2 = m % MPerXdl;
             constexpr index_t m1 = m / MPerXdl % Gemm0MWaves;
             constexpr index_t m0 = m / MPerXdl / Gemm0MWaves % MXdlPerWave;
 
             // perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl
-            constexpr index_t n  = Free0_N - 1;
+            constexpr index_t n  = Gemm2Params_N_O_M::Free0_N - 1;
             constexpr index_t n2 = n % NPerXdl;
             constexpr index_t n1 = n / NPerXdl % Gemm0NWaves;
             constexpr index_t n0 = n / NPerXdl / Gemm0NWaves % NXdlPerWave;
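A standalone sketch of the derived quantities in Gemm2Params_N_O_M_ under one assumed configuration (block size 256, wave size 64, 32x32 xdlops, a 128x64 output tile); the concrete numbers are for illustration only, not values taken from this kernel.

// Sketch only: mirrors the arithmetic of Gemm2Params_N_O_M_ with plain constexpr ints.
#include <cstdio>

int main()
{
    constexpr int BlockSize = 256, warp_size = 64;
    constexpr int MPerXdl = 32, NPerXdl = 32;
    constexpr int NPerBlock = 128, Gemm1NPerBlock = 64;

    constexpr int Sum_M   = MPerXdl * 2;    // default template argument
    constexpr int Free0_N = NPerBlock;      // dV rows (N of P^T)
    constexpr int Free1_O = Gemm1NPerBlock; // dV cols (O of dY)

    constexpr int A_M1 = 8, B_M1 = 2;       // LDS vector widths along the summed M dim
    constexpr int A_M0 = Sum_M / A_M1;      // 8
    constexpr int B_M0 = Sum_M / B_M1;      // 32

    constexpr int GemmNWave   = 2;
    constexpr int GemmOWave   = BlockSize / warp_size / GemmNWave; // 2
    constexpr int GemmNRepeat = Free0_N / GemmNWave / MPerXdl;     // 2
    constexpr int GemmORepeat = Free1_O / GemmOWave / NPerXdl;     // 1

    static_assert(Sum_M % MPerXdl == 0, "same constraint as the static_assert in the struct");

    std::printf("A tile %dx%dx%d, B tile %dx%dx%d, waves %dx%d, repeats %dx%d\n",
                A_M0, Free0_N, A_M1, B_M0, Free1_O, B_M1,
                GemmNWave, GemmOWave, GemmNRepeat, GemmORepeat);
    return 0;
}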
@@ -646,14 +625,212 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             return Sequence<m0, n0, m1, n1, m2, n2>{} + Sequence<1, 1, 1, 1, 1, 1>{};
         }
 
-        __host__ __device__ static constexpr auto GetPBlockSliceLengths_M0_N0_M1_N1()
+        __host__ __device__ static constexpr auto GetABlockSliceLengths_M0_N0_M1_N1()
         {
             return generate_sequence_v2(
-                [](auto I) { return GetPBlockSliceLengths_M0_N0_M1_N1_M2_N2().At(I); }, Number<4>{});
+                [](auto I) { return GetABlockSliceLengths_M0_N0_M1_N1_M2_N2().At(I); }, Number<4>{});
         }
+
+        using ABlockSliceLengths_M0_N0_M1_N1 = decltype(GetABlockSliceLengths_M0_N0_M1_N1());
     };
-    using VGradGemmTile_N_O_M = VGradGemmTile_N_O_M_<>; // tune later
+    using Gemm2Params_N_O_M = Gemm2Params_N_O_M_<>; // tune later
+
+    // dV / dK Gemm (type 3 crr)
+    template <typename Gemm2Params_N_O_M, typename ASrcBlockwiseGemm>
+    struct Gemm2
+    {
+        private:
+        static constexpr auto a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
+            ASrcBlockwiseGemm::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+        static constexpr auto M0 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); // repeat
+        static constexpr auto N0 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
+        static constexpr auto M1 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2); // wave
+        static constexpr auto N1 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
+        static constexpr auto M2 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4); // xdl
+        static constexpr auto N2 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
+        static constexpr auto N3 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
+        static constexpr auto N4 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
+
+        public:
+        // A source matrix layout in VGPR, src of VGPR-to-LDS copy
+        static constexpr auto a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
+            ASrcBlockwiseGemm::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+
+        // A matrix in LDS memory, dst of blockwise copy
+        static constexpr auto a_block_desc_m0_n_m1 = GetA2BlockDescriptor_M0_N_M1<Gemm2Params_N_O_M>();
+
+        // B matrix in LDS memory, dst of blockwise copy
+        static constexpr auto b_block_desc_m0_o_m1 = GetB2BlockDescriptor_M0_O_M1<Gemm2Params_N_O_M>();
+
+        __host__ __device__ static constexpr auto MakeABlockDesc_M0_N0_M1_N1_M2_N2_N3_N4()
+        {
+            const auto M0_ = a_block_desc_m0_n_m1.GetLength(I0);
+            const auto N_  = a_block_desc_m0_n_m1.GetLength(I1);
+            const auto M1_ = a_block_desc_m0_n_m1.GetLength(I2);
+
+            const auto a_block_desc_m_n = transform_tensor_descriptor(
+                a_block_desc_m0_n_m1,
+                make_tuple(make_merge_transform_v3_division_mod(make_tuple(M0_, M1_)),
+                           make_pass_through_transform(N_)),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+            // HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
+            // variable I1 there
+            return transform_tensor_descriptor(
+                a_block_desc_m_n,
+                make_tuple(make_unmerge_transform(make_tuple(I1, M1, M2)),
+                           make_unmerge_transform(make_tuple(I1, N1, N2, N3, N4))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
+        }
+
+        // Note: we will perform sub-workgroup VGPR-to-LDS copy to save LDS space, therefore the
+        // destination coordinate can overlap between wavefronts in a workgroup as seen in the mod
+        // operation before returning the values
+        __host__ __device__ static auto MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4()
+        {
+            const auto a_thread_origin_on_block_idx =
+                ASrcBlockwiseGemm::CalculateCThreadOriginDataIndex8D(I0, I0, I0, I0);
+
+            // mrepeat, nrepeat, mwaves, nwaves
+            constexpr auto c_block_slice_lengths_m0_n0_m1_n1 =
+                typename Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1{};
+
+            return make_tuple(
+                a_thread_origin_on_block_idx[I0],                                         // mrepeat
+                a_thread_origin_on_block_idx[I1],                                         // nrepeat
+                a_thread_origin_on_block_idx[I2] % c_block_slice_lengths_m0_n0_m1_n1[I2], // mwave
+                a_thread_origin_on_block_idx[I3] % c_block_slice_lengths_m0_n0_m1_n1[I3], // nwave
+                a_thread_origin_on_block_idx[I4],                                         // xdlops
+                a_thread_origin_on_block_idx[I5],
+                a_thread_origin_on_block_idx[I6],
+                a_thread_origin_on_block_idx[I7]);
+        }
+
+        static constexpr auto a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
+            MakeABlockDesc_M0_N0_M1_N1_M2_N2_N3_N4();
+
+        using ASrcBlockSliceWindowIterator =
+            SpaceFillingCurve<Sequence<M0, N0, M1, N1>,
+                              Sequence<0, 1, 2, 3>,
+                              typename Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1,
+                              false>;
+
+        using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
+            FloatGemmAcc,
+            DataType,
+            decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
+            decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
+            tensor_operation::element_wise::PassThrough,
+            Sequence<Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I0), // ThreadSliceLengths
+                     Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I1),
+                     I1, I1, I1, N2, I1, N4>,
+            Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
+            7,                              // DstVectorDim
+            1,                              // DstScalarPerVector
+            InMemoryDataOperationEnum::Set,
+            1,                              // DstScalarStrideInVector
+            true>;
+
+        template <typename GridDesc_M0_O_M1>
+        using BBlockwiseCopy =
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
+                                                tensor_operation::element_wise::PassThrough,
+                                                tensor_operation::element_wise::PassThrough,
+                                                InMemoryDataOperationEnum::Set,
+                                                typename Gemm2Params_N_O_M::BBlockSliceLengths,
+                                                typename Gemm2Params_N_O_M::BThreadClusterLengths,
+                                                typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder,
+                                                DataType,
+                                                DataType,
+                                                GridDesc_M0_O_M1,
+                                                decltype(b_block_desc_m0_o_m1),
+                                                typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder,
+                                                // access order == thread order
+                                                Sequence<1, 0, 2>,
+                                                Gemm2Params_N_O_M::BSrcVectorDim,
+                                                2, // DstVectorDim
+                                                Gemm2Params_N_O_M::BSrcScalarPerVector,
+                                                Gemm2Params_N_O_M::B_M1,
+                                                1,
+                                                1,
+                                                true,
+                                                true,
+                                                1>;
+
+        using BlockwiseGemm =
+            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+                                                                DataType,
+                                                                FloatGemmAcc,
+                                                                decltype(a_block_desc_m0_n_m1),
+                                                                decltype(b_block_desc_m0_o_m1),
+                                                                MPerXdl,
+                                                                NPerXdl,
+                                                                Gemm2Params_N_O_M::GemmNRepeat,
+                                                                Gemm2Params_N_O_M::GemmORepeat,
+                                                                Gemm2Params_N_O_M::GemmMPack,
+                                                                true>; // TransposeC
+
+        static constexpr auto b_block_slice_copy_step = make_multi_index(Gemm2Params_N_O_M::B_M0, 0, 0);
+        static constexpr auto c_block_slice_copy_step =
+            make_multi_index(Gemm2Params_N_O_M::GemmNRepeat, 0, 0, 0, 0, 0, 0, 0);
+        static constexpr auto b_block_reset_copy_step =
+            make_multi_index(-MPerBlock / Gemm2Params_N_O_M::B_M1, 0, 0);
+
+        template <typename CGradDesc_N_O>
+        __host__ __device__ static const auto
+        MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(CGradDesc_N_O c_grid_desc_n_o)
+        {
+            // HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
+            // variable I1 there
+            const auto c_grid_desc_n0_o0_n1_o1_n2_o2 = transform_tensor_descriptor(
+                c_grid_desc_n_o,
+                make_tuple(make_unmerge_transform(make_tuple(I1, Gemm2Params_N_O_M::GemmNWave, MPerXdl)),
+                           make_unmerge_transform(make_tuple(I1, Gemm2Params_N_O_M::GemmOWave, NPerXdl))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
+
+            const auto c_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
+                BlockwiseGemm{}.xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(
+                    c_grid_desc_n0_o0_n1_o1_n2_o2);
+
+            return c_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4;
+        }
+
+        static constexpr auto c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
+            BlockwiseGemm::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+
+        __host__ __device__ static const auto GetCThreadOriginOnBlock_N0_O0_N1_O1_N2_O2_O3_O4()
+        {
+            return to_multi_index(BlockwiseGemm::CalculateCThreadOriginDataIndex8D(I0, I0, I0, I0));
+        }
+
+        template <typename CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4>
+        using CBlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
+            FloatGemmAcc,
+            DataType,
+            decltype(c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4),
+            CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4,
+            tensor_operation::element_wise::PassThrough,                  // CElementwiseOperation
+            decltype(c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLengths()), // SliceLengths
+            Sequence<0, 1, 2, 3, 4, 5, 6, 7>,                             // AccessOrder
+            7,                                                            // VectorDim
+            2,                                                            // ScalarPerVector
+            InMemoryDataOperationEnum::AtomicAdd,                         // GlobalMemoryDataOperation
+            1,                                                            // DstScalarStrideInVector
+            true>;
+    };
 
     template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
     struct YDotYGrad_M_O_
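A standalone sketch of the "manual unmerge" used by GetABlockSliceLengths_M0_N0_M1_N1_M2_N2 above: the last flattened row index (Sum_M - 1) is decomposed into (repeat, wave, per-xdl) digits, and each digit plus one becomes a slice length. MPerXdl, Gemm0MWaves and MXdlPerWave below are assumed example values.

// Sketch only: same digit decomposition as the kernel, done with plain ints.
#include <array>

constexpr std::array<int, 3> unmerge_m(int m, int m_per_xdl, int m_waves, int m_xdl_per_wave)
{
    const int m2 = m % m_per_xdl;
    const int m1 = m / m_per_xdl % m_waves;
    const int m0 = m / m_per_xdl / m_waves % m_xdl_per_wave;
    return {m0, m1, m2};
}

int main()
{
    constexpr int Sum_M = 64, MPerXdl = 32, Gemm0MWaves = 2, MXdlPerWave = 4;
    constexpr auto d = unmerge_m(Sum_M - 1, MPerXdl, Gemm0MWaves, MXdlPerWave);
    // slice lengths are digit + 1, mirroring "Sequence<m0, ...>{} + Sequence<1, ...>{}"
    static_assert(d[0] + 1 == 1 && d[1] + 1 == 2 && d[2] + 1 == 32,
                  "a 64-row A slice covers 1 repeat, 2 waves, 32 rows per xdl");
    return 0;
}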
@@ -789,10 +966,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
         static constexpr auto b1_block_desc_bk0_n_bk1 =
             GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
-        static constexpr auto p_block_desc_m0_n_m1 =
-            VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();
-        static constexpr auto ygrad_block_desc_m0_o_m1 =
-            VGradGemmTile_N_O_M::GetYGradBlockDescriptor_M0_O_M1();
+        static constexpr auto a2_block_desc_m0_n_m1 =
+            GetA2BlockDescriptor_M0_N_M1<Gemm2Params_N_O_M>();
+        static constexpr auto b2_block_desc_m0_o_m1 =
+            GetB2BlockDescriptor_M0_O_M1<Gemm2Params_N_O_M>();
 
         static constexpr auto max_lds_align = Number<16 / sizeof(DataType)>{};
@@ -802,16 +979,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
         static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
             b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
         static constexpr auto p_block_space_size_aligned = math::integer_least_multiple(
-            p_block_desc_m0_n_m1.GetElementSpaceSize(), max_lds_align);
+            a2_block_desc_m0_n_m1.GetElementSpaceSize(), max_lds_align);
         static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
-            ygrad_block_desc_m0_o_m1.GetElementSpaceSize(), max_lds_align);
+            b2_block_desc_m0_o_m1.GetElementSpaceSize(), max_lds_align);
 
         static constexpr auto a_block_space_offset  = 0;
         static constexpr auto b_block_space_offset  = a_block_space_size_aligned.value;
         static constexpr auto b1_block_space_offset = 0;
-        static constexpr auto p_block_space_offset     = 0;
-        static constexpr auto ygrad_block_space_offset = p_block_space_size_aligned.value;
+        static constexpr auto a2_block_space_offset = 0;
+        static constexpr auto b2_block_space_offset = p_block_space_size_aligned.value;
 
         // LDS allocation for reduction
         static constexpr index_t reduction_space_size_aligned =
@@ -826,6 +1003,31 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
     };
 
+    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
+    {
+        const index_t gemm0_bytes_end =
+            (SharedMemTrait::a_block_space_size_aligned + SharedMemTrait::b_block_space_size_aligned) *
+            sizeof(DataType);
+        const index_t gemm1_bytes_end =
+            (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
+            sizeof(DataType);
+        const index_t vgrad_gemm_bytes_end =
+            (SharedMemTrait::p_block_space_size_aligned + SharedMemTrait::ygrad_block_space_size_aligned) *
+            sizeof(DataType);
+        const index_t softmax_bytes_end =
+            (SharedMemTrait::reduction_space_offset + SharedMemTrait::reduction_space_size_aligned) *
+            sizeof(FloatGemmAcc);
+        const index_t c_block_bytes_end = SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
+
+        return math::max(gemm0_bytes_end,
+                         gemm1_bytes_end,
+                         vgrad_gemm_bytes_end,
+                         softmax_bytes_end,
+                         c_block_bytes_end);
+    }
+
     template <bool HasMainKBlockLoop,
               typename Block2CTileMap,
               typename C0MatrixMask,
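A standalone sketch of the sizing rule in GetSharedMemoryNumberOfByte: the gemm0, gemm1, Gemm2, softmax and C-shuffle phases reuse the same LDS rather than stacking, so the allocation is the maximum end offset across phases. The element counts below are made-up placeholders, not values from the kernel.

// Sketch only: max over phase end offsets, as in the function above.
#include <algorithm>
#include <cstdio>

int main()
{
    constexpr int sizeof_data = 2, sizeof_acc = 4, sizeof_cshuffle = 2; // e.g. fp16 / fp32 / fp16

    constexpr int a_elems = 4096, b_elems = 4096;    // gemm0 A+B tiles
    constexpr int b1_off = 0, b1_elems = 2048;       // gemm1 B tile (reuses offset 0)
    constexpr int a2_elems = 8192, b2_elems = 2048;  // Gemm2 A+B tiles (a2 at offset 0, b2 after it)
    constexpr int red_off = 0, red_elems = 512;      // softmax reduction workspace
    constexpr int c_elems = 4096;                    // C shuffle tile

    constexpr int gemm0_end   = (a_elems + b_elems) * sizeof_data;
    constexpr int gemm1_end   = (b1_off + b1_elems) * sizeof_data;
    constexpr int gemm2_end   = (a2_elems + b2_elems) * sizeof_data;
    constexpr int softmax_end = (red_off + red_elems) * sizeof_acc;
    constexpr int c_end       = c_elems * sizeof_cshuffle;

    constexpr int lds_bytes = std::max({gemm0_end, gemm1_end, gemm2_end, softmax_end, c_end});
    std::printf("LDS needed: %d bytes\n", lds_bytes);
    return 0;
}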
@@ -1118,276 +1320,75 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         //
         // set up dV / dK Gemm (type 3 crr)
         //
-        // P vgpr to lds: writes vgprs of a subgroup to LDS and transform into m0_n_m1
-        // m0, n0 are m/n repeat per wave
-        // m1, n1 are number of waves
-        constexpr auto p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
-            s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
-
-        constexpr auto p_block_desc_m0_n_m1 = VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();
-
-        constexpr auto p_block_lengths =
-            s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
-        constexpr auto P_M0 = p_block_lengths[I0]; // repeats
-        constexpr auto P_N0 = p_block_lengths[I1];
-        constexpr auto P_M1 = p_block_lengths[I2]; // waves
-        constexpr auto P_N1 = p_block_lengths[I3];
-        constexpr auto P_M2 = p_block_lengths[I4]; // xdl
-        constexpr auto P_N2 = p_block_lengths[I5];
-        constexpr auto P_N3 = p_block_lengths[I6];
-        constexpr auto P_N4 = p_block_lengths[I7];
-
-        constexpr auto p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = [&]() constexpr {
-            constexpr auto p_block_desc_m_n = transform_tensor_descriptor(
-                p_block_desc_m0_n_m1,
-                make_tuple(make_merge_transform_v3_division_mod(
-                               make_tuple(VGradGemmTile_N_O_M::P_M0, VGradGemmTile_N_O_M::P_M1)),
-                           make_pass_through_transform(VGradGemmTile_N_O_M::Free0_N)),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-
-            // HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
-            // variable I1 there
-            return transform_tensor_descriptor(
-                p_block_desc_m_n,
-                make_tuple(make_unmerge_transform(make_tuple(I1, P_M1, P_M2)),
-                           make_unmerge_transform(make_tuple(I1, P_N1, P_N2, P_N3, P_N4))),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
-        }();
-
-        const auto p_thread_origin_nd_idx_on_block = [&]() {
-            const auto c_thread_mtx_on_block =
-                s_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
-
-            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
-            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
-
-            const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
-                make_single_stage_tensor_adaptor(
-                    make_tuple(make_merge_transform(make_tuple(P_M0, P_M1, P_M2))),
-                    make_tuple(Sequence<0, 1, 2>{}),
-                    make_tuple(Sequence<0>{}));
-
-            const auto m_thread_data_on_block_idx =
-                m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
-                    make_multi_index(m_thread_data_on_block));
-
-            const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
-                make_single_stage_tensor_adaptor(
-                    make_tuple(make_merge_transform(make_tuple(P_N0, P_N1, P_N2, P_N3, P_N4))),
-                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
-                    make_tuple(Sequence<0>{}));
-
-            const auto n_thread_data_on_block_idx =
-                n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
-                    make_multi_index(n_thread_data_on_block));
-
-            return make_tuple(m_thread_data_on_block_idx[I0], // mrepeat
-                              n_thread_data_on_block_idx[I0], // nrepeat
-                              m_thread_data_on_block_idx[I1], // mwave
-                              n_thread_data_on_block_idx[I1], // nwave
-                              m_thread_data_on_block_idx[I2], // xdlops
-                              n_thread_data_on_block_idx[I2],
-                              n_thread_data_on_block_idx[I3],
-                              n_thread_data_on_block_idx[I4]);
-        }();
-
-        constexpr auto p_block_slice_lengths_m0_n0_m1_n1 = // mrepeat, nrepeat, mwaves, nwaves
-            VGradGemmTile_N_O_M::GetPBlockSliceLengths_M0_N0_M1_N1();
-
-        // how to properly perform copy for a sub-workgroup?
-        auto p_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
-            FloatGemmAcc,
-            DataType,
-            decltype(p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
-            decltype(p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
-            tensor_operation::element_wise::PassThrough,
-            Sequence<p_block_slice_lengths_m0_n0_m1_n1[I0], // ThreadSliceLengths
-                     p_block_slice_lengths_m0_n0_m1_n1[I1],
-                     I1, I1, I1, P_N2, I1, P_N4>,
-            Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
-            7,                              // DstVectorDim
-            1,                              // DstScalarPerVector
-            InMemoryDataOperationEnum::Set,
-            1,                              // DstScalarStrideInVector
-            true>{p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-                  make_multi_index(
-                      p_thread_origin_nd_idx_on_block[I0],
-                      p_thread_origin_nd_idx_on_block[I1],
-                      p_thread_origin_nd_idx_on_block[I2] % p_block_slice_lengths_m0_n0_m1_n1[I2],
-                      p_thread_origin_nd_idx_on_block[I3] % p_block_slice_lengths_m0_n0_m1_n1[I3],
-                      p_thread_origin_nd_idx_on_block[I4],
-                      p_thread_origin_nd_idx_on_block[I5],
-                      p_thread_origin_nd_idx_on_block[I6],
-                      p_thread_origin_nd_idx_on_block[I7]),
-                  tensor_operation::element_wise::PassThrough{}};
-
-        constexpr auto sfc_p_m0_n0_m1_n1_m2_n2 =
-            SpaceFillingCurve<Sequence<P_M0, P_N0, P_M1, P_N1>,
-                              Sequence<0, 1, 2, 3>,
-                              decltype(p_block_slice_lengths_m0_n0_m1_n1),
-                              false>{};
-
-        constexpr auto ygrad_block_desc_m0_o_m1 =
-            VGradGemmTile_N_O_M::GetYGradBlockDescriptor_M0_O_M1();
-
-        auto ygrad_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
-            ThisThreadBlock,
-            tensor_operation::element_wise::PassThrough,
-            tensor_operation::element_wise::PassThrough,
-            InMemoryDataOperationEnum::Set,
-            typename VGradGemmTile_N_O_M::YGrad_BlockSliceLengths,
-            typename VGradGemmTile_N_O_M::YGrad_ThreadClusterLengths,
-            typename VGradGemmTile_N_O_M::YGrad_ThreadClusterArrangeOrder,
-            DataType,
-            DataType,
-            decltype(ygrad_grid_desc_m0_o_m1),
-            decltype(ygrad_block_desc_m0_o_m1),
-            typename VGradGemmTile_N_O_M::YGrad_ThreadClusterArrangeOrder, // access order == thread order
-            Sequence<1, 0, 2>,
-            VGradGemmTile_N_O_M::YGrad_SrcVectorDim,
-            2, // DstVectorDim
-            VGradGemmTile_N_O_M::YGrad_SrcScalarPerVector,
-            VGradGemmTile_N_O_M::YGrad_M1,
-            1,
-            1,
-            true,
-            true,
-            1>(ygrad_grid_desc_m0_o_m1,
-               make_multi_index(m_block_data_idx_on_grid / VGradGemmTile_N_O_M::YGrad_M1,
-                                o_block_data_idx_on_grid,
-                                0),
-               tensor_operation::element_wise::PassThrough{},
-               ygrad_block_desc_m0_o_m1,
-               make_multi_index(0, 0, 0),
-               tensor_operation::element_wise::PassThrough{});
-
-        auto p_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared) + SharedMemTrait::p_block_space_offset,
-            p_block_desc_m0_n_m1.GetElementSpaceSize());
-
-        auto ygrad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared) + SharedMemTrait::ygrad_block_space_offset,
-            ygrad_block_desc_m0_o_m1.GetElementSpaceSize());
-
-        auto vgrad_blockwise_gemm =
-            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
-                                                                DataType,
-                                                                FloatGemmAcc,
-                                                                decltype(p_block_desc_m0_n_m1),
-                                                                decltype(ygrad_block_desc_m0_o_m1),
-                                                                MPerXdl,
-                                                                NPerXdl,
-                                                                VGradGemmTile_N_O_M::GemmNRepeat,
-                                                                VGradGemmTile_N_O_M::GemmORepeat,
-                                                                VGradGemmTile_N_O_M::GemmMPack,
-                                                                true>{}; // TransposeC
+        using Gemm2 = Gemm2<Gemm2Params_N_O_M, decltype(s_blockwise_gemm)>;
+
+        // Gemm2: LDS allocation for A and B: be careful of alignment
+        auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+            static_cast<DataType*>(p_shared) + SharedMemTrait::a2_block_space_offset,
+            Gemm2::a_block_desc_m0_n_m1.GetElementSpaceSize());
+
+        auto gemm2_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+            static_cast<DataType*>(p_shared) + SharedMemTrait::b2_block_space_offset,
+            Gemm2::b_block_desc_m0_o_m1.GetElementSpaceSize());
+
+        // dV: transform input and output tensor descriptors
+        const auto vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
+            Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(vgrad_grid_desc_n_o);
+
+        // dV: A matrix VGPR-to-LDS blockwise copy
+        auto vgrad_gemm_tile_p_thread_copy_vgpr_to_lds = typename Gemm2::ABlockwiseCopy{
+            Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+            Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(),
+            tensor_operation::element_wise::PassThrough{}};
+
+        constexpr auto vgrad_gemm_tile_p_block_slice_window_iterator =
+            typename Gemm2::ASrcBlockSliceWindowIterator{};
+
+        // dV: B matrix global-to-LDS blockwise copy
+        auto vgrad_gemm_tile_ygrad_blockwise_copy =
+            typename Gemm2::template BBlockwiseCopy<decltype(ygrad_grid_desc_m0_o_m1)>(
+                ygrad_grid_desc_m0_o_m1,
+                make_multi_index(m_block_data_idx_on_grid / Gemm2Params_N_O_M::B_M1,
+                                 o_block_data_idx_on_grid,
+                                 0),
+                tensor_operation::element_wise::PassThrough{},
+                Gemm2::b_block_desc_m0_o_m1,
+                make_multi_index(0, 0, 0),
+                tensor_operation::element_wise::PassThrough{});
+
+        // dV: blockwise gemm
+        auto vgrad_blockwise_gemm = typename Gemm2::BlockwiseGemm{};
 
         auto vgrad_thread_buf = vgrad_blockwise_gemm.GetCThreadBuffer();
 
-        // HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
-        // variable I1 there
-        const auto vgrad_grid_desc_n0_o0_n1_o1_n2_o2 = transform_tensor_descriptor(
-            vgrad_grid_desc_n_o,
-            make_tuple(make_unmerge_transform(make_tuple(I1, VGradGemmTile_N_O_M::GemmNWave, MPerXdl)),
-                       make_unmerge_transform(make_tuple(I1, VGradGemmTile_N_O_M::GemmOWave, NPerXdl))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
-
-        constexpr auto vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
-            vgrad_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
-
-        const auto vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
-            vgrad_blockwise_gemm.xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(
-                vgrad_grid_desc_n0_o0_n1_o1_n2_o2);
-
-        const auto vgrad_thread_mtx_on_block_n_o =
-            vgrad_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
-
-        constexpr auto vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
-            decltype(vgrad_blockwise_gemm)::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
-        constexpr auto VGrad_N0 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I0);
-        constexpr auto VGrad_O0 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I1);
-        constexpr auto VGrad_N1 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I2);
-        constexpr auto VGrad_O1 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I3);
-        constexpr auto VGrad_N2 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I4);
-        constexpr auto VGrad_O2 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I5);
-        constexpr auto VGrad_O3 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I6);
-        constexpr auto VGrad_O4 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I7);
-
-        const index_t n_thread_data_idx_on_grid = vgrad_thread_mtx_on_block_n_o[I0];
-
-        const index_t o_thread_data_idx_on_grid =
-            vgrad_thread_mtx_on_block_n_o[I1] + o_block_data_idx_on_grid;
-
-        const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(make_tuple(VGrad_N0, VGrad_N1, VGrad_N2))),
-            make_tuple(Sequence<0, 1, 2>{}),
-            make_tuple(Sequence<0>{}));
-
-        const auto n_thread_data_nd_idx_on_grid =
-            n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
-                make_multi_index(n_thread_data_idx_on_grid));
-
-        const auto o_thread_data_on_grid_to_o0_o1_o2_o3_o4_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(
-                make_tuple(VGrad_O0, VGrad_O1, VGrad_O2, VGrad_O3, VGrad_O4))),
-            make_tuple(Sequence<0, 1, 2, 3, 4>{}),
-            make_tuple(Sequence<0>{}));
-
-        const auto o_thread_data_nd_idx_on_grid =
-            o_thread_data_on_grid_to_o0_o1_o2_o3_o4_adaptor.CalculateBottomIndex(
-                make_multi_index(o_thread_data_idx_on_grid));
-
-        auto vgrad_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
-            FloatGemmAcc,
-            DataType,
-            decltype(vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4),
-            decltype(vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4),
-            tensor_operation::element_wise::PassThrough,                      // CElementwiseOperation
-            decltype(vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLengths()), // SliceLengths
-            Sequence<0, 1, 2, 3, 4, 5, 6, 7>,                                 // AccessOrder
-            7,                                                                // VectorDim
-            2,                                                                // ScalarPerVector
-            InMemoryDataOperationEnum::AtomicAdd,                             // GlobalMemoryDataOperation
-            1,                                                                // DstScalarStrideInVector
-            true>(vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
-                  make_multi_index(n_thread_data_nd_idx_on_grid[I0],
-                                   o_thread_data_nd_idx_on_grid[I0],
-                                   n_thread_data_nd_idx_on_grid[I1],
-                                   o_thread_data_nd_idx_on_grid[I1],
-                                   n_thread_data_nd_idx_on_grid[I2],
-                                   o_thread_data_nd_idx_on_grid[I2],
-                                   o_thread_data_nd_idx_on_grid[I3],
-                                   o_thread_data_nd_idx_on_grid[I4]),
-                  tensor_operation::element_wise::PassThrough{});
-
-        // p_thread_slice_copy_step will be in for loop
-        constexpr auto ygrad_block_slice_copy_step =
-            make_multi_index(VGradGemmTile_N_O_M::YGrad_M0, 0, 0);
-        constexpr auto ygrad_block_reset_copy_step =
-            make_multi_index(-MPerBlock / VGradGemmTile_N_O_M::YGrad_M1, 0, 0);
-
-        // vgrad gemm output tile
-        const auto vgrad_block_slice_copy_step =
-            make_multi_index(VGradGemmTile_N_O_M::GemmNRepeat, 0, 0, 0, 0, 0, 0, 0);
+        // dV: C VGPR-to-global copy
+        const auto vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4 =
+            Gemm2::GetCThreadOriginOnBlock_N0_O0_N1_O1_N2_O2_O3_O4() +
+            make_multi_index(I0, block_work_idx[I1], I0, I0, I0, I0, I0, I0);
+
+        auto vgrad_thread_copy_vgpr_to_global =
+            typename Gemm2::template CBlockwiseCopy<decltype(vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4)>(
+                vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+                vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
+                tensor_operation::element_wise::PassThrough{});
+
+        // dK: transform input and output tensor descriptors
 
         //
         // set up Y dot dY
         //
+        // m0, n0 are m/n repeat per wave
+        // m1, n1 are number of waves
+        constexpr auto p_block_lengths =
+            s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+        constexpr auto P_M0 = p_block_lengths[I0]; // repeats
+        constexpr auto P_M1 = p_block_lengths[I2]; // waves
+        constexpr auto P_M2 = p_block_lengths[I4]; // xdl
+
         constexpr auto y_thread_desc_m0_m1_o0_o1 = make_naive_tensor_descriptor_packed(
             make_tuple(I1, YDotYGrad_M_O::ThreadSliceLength_M, I1, YDotYGrad_M_O::ThreadSliceLength_O));
 
         constexpr auto y_thread_cluster_desc =
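A standalone CPU sketch of what the dV path computes: P (MPerBlock x N) and dY (MPerBlock x O) are consumed Sum_M rows at a time, and every slice accumulates P_slice^T * dY_slice into the same N x O tile; the GPU version keeps the accumulator in VGPRs and finally AtomicAdds it into the dV grid. Sizes below are toy values.

// Sketch only: the loop structure that num_vgrad_gemm_loop iterations realize.
#include <cassert>
#include <vector>

int main()
{
    const int MPerBlock = 8, Sum_M = 4, N = 3, O = 2;
    std::vector<float> P(MPerBlock * N, 1.0f), dY(MPerBlock * O, 2.0f), dV(N * O, 0.0f);

    for(int m0 = 0; m0 < MPerBlock; m0 += Sum_M)      // one pass per Sum_M-row slice
        for(int n = 0; n < N; ++n)
            for(int o = 0; o < O; ++o)
                for(int m = m0; m < m0 + Sum_M; ++m)  // dV[n][o] += P[m][n] * dY[m][o]
                    dV[n * O + o] += P[m * N + n] * dY[m * O + o];

    assert(dV[0] == MPerBlock * 1.0f * 2.0f); // every output element saw all MPerBlock rows
    return 0;
}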
@@ -1617,23 +1618,28 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
 
             block_sync_lds(); // wait for gemm1 LDS read
 
+            constexpr auto p_block_slice_lengths_m0_n0_m1_n1 =
+                typename Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1{};
             SubThreadBlock<BlockSize> p_thread_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
                                                              s_blockwise_gemm.GetWaveIdx()[I1]);
 
-            constexpr index_t num_vgrad_gemm_loop = MPerBlock / VGradGemmTile_N_O_M::Sum_M;
-            static_assert(sfc_p_m0_n0_m1_n1_m2_n2.GetNumOfAccess() == num_vgrad_gemm_loop, "");
+            constexpr index_t num_vgrad_gemm_loop = MPerBlock / Gemm2Params_N_O_M::Sum_M;
+            static_assert(vgrad_gemm_tile_p_block_slice_window_iterator.GetNumOfAccess() ==
+                              num_vgrad_gemm_loop,
+                          "");
 
-            vgrad_thread_buf.Clear();
-            // TODO: tune gemm2 pipeline
+            // TODO: tune pipeline
             // dV = P^T * dY
+            vgrad_thread_buf.Clear();
             static_for<0, num_vgrad_gemm_loop, 1>{}([&](auto vgrad_gemm_loop_idx) { // gemm dV
                 // load VGrad Gemm B
-                ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1, ygrad_grid_buf);
+                vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1, ygrad_grid_buf);
 
                 // load VGrad Gemm A
                 const auto p_nd_idx =
-                    sfc_p_m0_n0_m1_n1_m2_n2.GetIndexTupleOfNumber(vgrad_gemm_loop_idx);
+                    vgrad_gemm_tile_p_block_slice_window_iterator.GetIndexTupleOfNumber(vgrad_gemm_loop_idx);
 
                 constexpr auto mwave_range =
                     make_tuple(p_nd_idx[I2], p_nd_idx[I2] + p_block_slice_lengths_m0_n0_m1_n1[I2]);
                 constexpr auto nwave_range =
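A standalone sketch of the gating done by p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range): only the waves whose (mwave, nwave) coordinates fall inside the current slice window participate in the VGPR-to-LDS copy, which is what makes the sub-workgroup copy cheap on LDS. The helper below is hypothetical and only illustrates the half-open range test, not the SubThreadBlock API.

// Sketch only: a wave participates when its coordinate lies in [begin, end).
#include <cassert>

constexpr bool is_belong(int wave, int begin, int end) { return begin <= wave && wave < end; }

int main()
{
    // assume the slice window covers mwaves [0, 2) and nwaves [0, 2) of a 2x2 wave grid
    assert(is_belong(/*mwave*/ 1, 0, 2) && is_belong(/*nwave*/ 0, 0, 2)); // participates
    assert(!is_belong(/*mwave*/ 2, 0, 2));                                // sits outside the window
    return 0;
}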
@@ -1641,28 +1647,29 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 if(p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range))
                 {
-                    p_thread_copy_vgpr_to_lds.Run(p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-                                                  make_tuple(p_nd_idx[I0], p_nd_idx[I1], I0, I0, I0, I0, I0, I0),
-                                                  s_slash_p_thread_buf,
-                                                  p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-                                                  p_block_buf);
+                    vgrad_gemm_tile_p_thread_copy_vgpr_to_lds.Run(
+                        Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                        make_tuple(p_nd_idx[I0], p_nd_idx[I1], I0, I0, I0, I0, I0, I0),
+                        s_slash_p_thread_buf,
+                        Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                        gemm2_a_block_buf);
                 }
 
                 // ygrad slice window is moved with MoveSrcSliceWindow() since it is dynamic buffer
                 // p slice window is moved by loop index
-                ygrad_blockwise_copy.MoveSrcSliceWindow(ygrad_grid_desc_m0_o_m1,
-                                                        ygrad_block_slice_copy_step);
+                vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(ygrad_grid_desc_m0_o_m1,
+                                                                        Gemm2::b_block_slice_copy_step);
 
                 block_sync_lds(); // sync before write
-                ygrad_blockwise_copy.RunWrite(ygrad_block_desc_m0_o_m1, ygrad_block_buf);
+                vgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1, gemm2_b_block_buf);
 
                 block_sync_lds(); // sync before read
-                vgrad_blockwise_gemm.Run(p_block_buf, ygrad_block_buf, vgrad_thread_buf);
+                vgrad_blockwise_gemm.Run(gemm2_a_block_buf, gemm2_b_block_buf, vgrad_thread_buf);
             }); // end gemm dV
 
             // atomic_add dV
-            vgrad_thread_copy_vgpr_to_global.Run(vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+            vgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
                                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                                  vgrad_thread_buf,
                                                  vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
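A standalone sketch of why the dV store above is an AtomicAdd ("atomic_add dV"): workgroups covering different M ranges of the same (n, o) output tile each contribute a partial P^T * dY sum to the same dV element, so each store must add rather than overwrite; a plain += stands in for the atomic here.

// Sketch only: two "workgroups" accumulating into one dV element.
#include <cassert>

int main()
{
    float dv_no = 0.0f;                               // one dV element in global memory
    const float partial_from_block[2] = {3.0f, 4.0f}; // partial sums over disjoint M ranges
    for(int block = 0; block < 2; ++block)
        dv_no += partial_from_block[block];           // AtomicAdd on the GPU, plain += here
    assert(dv_no == 7.0f);
    return 0;
}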
@@ -1777,10 +1784,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             s_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
                 k_grid_desc_k0_n_k1,
                 s_gemm_tile_b_block_reset_copy_step); // rewind K and step N
-            ygrad_blockwise_copy.MoveSrcSliceWindow(ygrad_grid_desc_m0_o_m1,
-                                                    ygrad_block_reset_copy_step); // rewind M
+            vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
+                ygrad_grid_desc_m0_o_m1,
+                Gemm2::b_block_reset_copy_step); // rewind M
             vgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                 vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
-                vgrad_block_slice_copy_step); // step N
+                Gemm2::c_block_slice_copy_step); // step N
 
             pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
                 ygrad_grid_desc_o0_m_o1,
                 pgrad_gemm_tile_ygrad_block_reset_copy_step); // rewind O
             pgrad_gemm_tile_v_blockwise_copy.MoveSrcSliceWindow(