Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ef6933a2
Commit
ef6933a2
authored
Jan 12, 2022
by
ltqin
Browse files
merge develop and remove hacks
parent
7506342c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
98 additions
and
135 deletions
+98
-135
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3r1.hpp
.../include/tensor_operation/gridwise_gemm_xdlops_v2r3r1.hpp
+81
-114
device_operation/include/device_gemm_xdl.hpp
device_operation/include/device_gemm_xdl.hpp
+17
-21
No files found.
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3r1.hpp
View file @
ef6933a2
...
...
@@ -6,9 +6,8 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
#include "blockwise_tensor_slice_transfer
_v4r1
.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp"
namespace
ck
{
...
...
@@ -140,9 +139,8 @@ template <index_t BlockSize,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
K1Value
,
index_t
MRepeat
,
index_t
NRepeat
,
typename
ABlockTransferThreadSliceLengths_K0_M_K1
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
...
...
@@ -150,7 +148,7 @@ template <index_t BlockSize,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
typename
BBlockTransferThreadSliceLengths_K0_N_K1
,
bool
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
...
...
@@ -158,17 +156,10 @@ template <index_t BlockSize,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
bool
BBlockLdsExtraN
,
typename
CThreadTransferSrcDstAccessOrder
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferDstScalarPerVector
,
typename
AGridStepHacks
,
typename
BGridStepHacks
,
typename
CGridStepHacks
,
typename
AGridMoveSliceWindowStepHacks
,
typename
BGridMoveSliceWindowStepHacks
,
bool
CAccessOrderMRepeatNRepeat
,
bool
ABlockLdsExtraM
,
bool
BBlockLdsExtraN
>
index_t
CThreadTransferDstScalarPerVector
>
struct
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -238,8 +229,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
"wrong! K1 need to be known at compile-time"
);
static_assert
((
MPerBlock
%
(
MPerXDL
*
M
Repeat
)
==
0
)
&&
(
NPerBlock
%
(
N
Repeat
*
NPerXDL
))
==
0
,
static_assert
((
MPerBlock
%
(
MPerXDL
*
M
XdlPerWave
)
==
0
)
&&
(
NPerBlock
%
(
N
XdlPerWave
*
NPerXDL
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
a_grid_desc_k0_m_k1
.
GetLength
(
I1
);
...
...
@@ -336,8 +327,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
decltype
(
b_block_desc_k0_n_k1
),
MPerXDL
,
NPerXDL
,
M
Repeat
,
N
Repeat
,
M
XdlPerWave
,
N
XdlPerWave
,
K1
>
;
return
BlockwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n
);
...
...
@@ -452,11 +443,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
// A matrix blockwise copy
auto
a_blockwise_copy
=
BlockwiseTensorSliceTransfer_v4
<
BlockSize
,
BlockwiseTensorSliceTransfer_v4
r1
<
BlockSize
,
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum_t
::
Set
,
Sequence
<
K0PerBlock
,
MPerBlock
,
K1
>
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
FloatAB
,
...
...
@@ -472,19 +463,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
1
,
1
,
AThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
a_grid_desc_k0_m_k1
,
true
>
(
a_grid_desc_k0_m_k1
,
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
a_element_op
,
a_block_desc_k0_m_k1
,
make_multi_index
(
0
,
0
,
0
),
a_element_op
);
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}
);
// B matrix blockwise copy
auto
b_blockwise_copy
=
BlockwiseTensorSliceTransfer_v4
<
BlockSize
,
BlockwiseTensorSliceTransfer_v4
r1
<
BlockSize
,
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum_t
::
Set
,
Sequence
<
K0PerBlock
,
NPerBlock
,
K1
>
,
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
FloatAB
,
...
...
@@ -500,11 +493,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
1
,
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
b_grid_desc_k0_n_k1
,
true
>
(
b_grid_desc_k0_n_k1
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
b_element_op
,
b_block_desc_k0_n_k1
,
make_multi_index
(
0
,
0
,
0
),
b_element_op
);
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}
);
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
...
...
@@ -522,8 +517,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
decltype
(
b_block_desc_k0_n_k1
),
MPerXDL
,
NPerXDL
,
M
Repeat
,
N
Repeat
,
M
XdlPerWave
,
N
XdlPerWave
,
K1
>
{};
auto
c_thread_buf
=
blockwise_gemm
.
GetCThreadBuffer
();
...
...
@@ -541,15 +536,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
K0PerBlock
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
K0PerBlock
,
0
,
0
);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
AGridStepHacks
{};
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
BGridStepHacks
{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hack
=
AGridMoveSliceWindowStepHacks
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hack
=
BGridMoveSliceWindowStepHacks
{};
auto
a_block_even_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
p_a_block_double
,
a_block_desc_k0_m_k1
.
GetElementSpaceSize
());
auto
b_block_even_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
...
...
@@ -562,8 +548,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
// preload data into LDS
{
a_blockwise_copy
.
RunRead
(
a_grid_desc_k0_m_k1
,
a_grid_buf
,
a_k0_m_k1_grid_step_hacks
);
b_blockwise_copy
.
RunRead
(
b_grid_desc_k0_n_k1
,
b_grid_buf
,
b_k0_n_k1_grid_step_hacks
);
a_blockwise_copy
.
RunRead
(
a_grid_desc_k0_m_k1
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc_k0_n_k1
,
b_grid_buf
);
a_blockwise_copy
.
RunWrite
(
a_block_desc_k0_m_k1
,
a_block_even_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc_k0_n_k1
,
b_block_even_buf
);
...
...
@@ -579,21 +565,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
do
{
// iteration for odd
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc_k0_m_k1
,
a_block_slice_copy_step
,
a_k0_m_k1_grid_move_slice_window_step_hack
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_n_k1
,
b_block_slice_copy_step
,
b_k0_n_k1_grid_move_slice_window_step_hack
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc_k0_m_k1
,
a_block_slice_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_n_k1
,
b_block_slice_copy_step
);
// LDS double buffer: load last data from device mem
a_blockwise_copy
.
RunRead
(
a_grid_desc_k0_m_k1
,
a_grid_buf
,
a_k0_m_k1_grid_step_hacks
);
a_blockwise_copy
.
RunRead
(
a_grid_desc_k0_m_k1
,
a_grid_buf
);
block_sync_lds
();
b_blockwise_copy
.
RunRead
(
b_grid_desc_k0_n_k1
,
b_grid_buf
,
b_k0_n_k1_grid_step_hacks
);
b_blockwise_copy
.
RunRead
(
b_grid_desc_k0_n_k1
,
b_grid_buf
);
// gemm even data
blockwise_gemm
.
Run
(
a_block_even_buf
,
b_block_even_buf
,
c_thread_buf
);
...
...
@@ -602,21 +582,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
b_blockwise_copy
.
RunWrite
(
b_block_desc_k0_n_k1
,
b_block_odd_buf
);
// iteration for even
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc_k0_m_k1
,
a_block_slice_copy_step
,
a_k0_m_k1_grid_move_slice_window_step_hack
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_n_k1
,
b_block_slice_copy_step
,
b_k0_n_k1_grid_move_slice_window_step_hack
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc_k0_m_k1
,
a_block_slice_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_n_k1
,
b_block_slice_copy_step
);
// LDS double buffer: load last data from device mem
a_blockwise_copy
.
RunRead
(
a_grid_desc_k0_m_k1
,
a_grid_buf
,
a_k0_m_k1_grid_step_hacks
);
a_blockwise_copy
.
RunRead
(
a_grid_desc_k0_m_k1
,
a_grid_buf
);
block_sync_lds
();
b_blockwise_copy
.
RunRead
(
b_grid_desc_k0_n_k1
,
b_grid_buf
,
b_k0_n_k1_grid_step_hacks
);
b_blockwise_copy
.
RunRead
(
b_grid_desc_k0_n_k1
,
b_grid_buf
);
// gemm odd data
blockwise_gemm
.
Run
(
a_block_odd_buf
,
b_block_odd_buf
,
c_thread_buf
);
...
...
@@ -632,17 +606,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
if
constexpr
(
HasDoubleTailKBlockLoop
)
// if has 2 iteration left
{
// iteration for odd
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc_k0_m_k1
,
a_block_slice_copy_step
,
a_k0_m_k1_grid_move_slice_window_step_hack
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_n_k1
,
b_block_slice_copy_step
,
b_k0_n_k1_grid_move_slice_window_step_hack
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc_k0_m_k1
,
a_block_slice_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_n_k1
,
b_block_slice_copy_step
);
// LDS double buffer: load last data from device mem
a_blockwise_copy
.
RunRead
(
a_grid_desc_k0_m_k1
,
a_grid_buf
,
a_k0_m_k1_grid_step_hacks
);
a_blockwise_copy
.
RunRead
(
a_grid_desc_k0_m_k1
,
a_grid_buf
);
block_sync_lds
();
b_blockwise_copy
.
RunRead
(
b_grid_desc_k0_n_k1
,
b_grid_buf
,
b_k0_n_k1_grid_step_hacks
);
b_blockwise_copy
.
RunRead
(
b_grid_desc_k0_n_k1
,
b_grid_buf
);
// gemm even data
blockwise_gemm
.
Run
(
a_block_even_buf
,
b_block_even_buf
,
c_thread_buf
);
...
...
@@ -689,8 +659,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
const
index_t
n_thread_data_on_grid
=
n_block_data_idx_on_grid
+
c_thread_mtx_on_block
[
I1
];
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks
=
CGridStepHacks
{};
const
auto
m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M0
,
M1
,
M2
,
M3
,
M4
))),
...
...
@@ -738,8 +706,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_grid_buf
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks
);
c_grid_buf
);
}
}
};
// namespace ck
...
...
device_operation/include/device_gemm_xdl.hpp
View file @
ef6933a2
...
...
@@ -157,7 +157,6 @@ struct DeviceGemmXdl
K1
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
...
...
@@ -165,7 +164,7 @@ struct DeviceGemmXdl
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// AThreadTransferSrcResetCoordinateAfterRun,
B
Block
TransferThreadSliceLengths_K0_N_K1
,
A
Block
LdsAddExtraM
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
...
...
@@ -173,17 +172,10 @@ struct DeviceGemmXdl
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN
,
Sequence
<
0
,
2
,
4
,
5
,
6
,
1
,
3
,
7
>
,
// CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
decltype
(
a_k0_m_k1_grid_step_hacks
),
// AGridStepHacks,
decltype
(
b_k0_n_k1_grid_step_hacks
),
// BGridStepHacks,
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
// CGridStepHacks,
decltype
(
a_k0_m_k1_grid_move_slice_window_step_hacks
),
// AGridMoveSliceWindowStepHacks,
decltype
(
b_k0_n_k1_grid_move_slice_window_step_hacks
),
// BGridMoveSliceWindowStepHacks,
false
,
// CAccessOrderMRepeatNRepeat,
ABlockLdsAddExtraM
,
BBlockLdsAddExtraN
>
;
CThreadTransferDstScalarPerVector
>
;
#else
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<
BlockSize
,
...
...
@@ -224,7 +216,7 @@ struct DeviceGemmXdl
Sequence
<
0
,
2
,
4
,
5
,
6
,
1
,
3
,
7
>
,
// CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
>
;
#endif
// Argument
struct
Argument
:
public
BaseArgument
{
...
...
@@ -337,11 +329,12 @@ struct DeviceGemmXdl
CDataType
,
remove_reference_t
<
DeviceGemmXdl
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdl
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmXdl
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
remove_reference_t
<
Devic
eGemm
Xdl
::
Block2CTileMap
>
,
remove_reference_t
<
typename
Gridwis
eGemm
::
Block2CTileMap
>
,
true
,
true
>
;
...
...
@@ -369,11 +362,12 @@ struct DeviceGemmXdl
CDataType
,
remove_reference_t
<
DeviceGemmXdl
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdl
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmXdl
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
remove_reference_t
<
Devic
eGemm
Xdl
::
Block2CTileMap
>
,
remove_reference_t
<
typename
Gridwis
eGemm
::
Block2CTileMap
>
,
true
,
false
>
;
...
...
@@ -404,11 +398,12 @@ struct DeviceGemmXdl
CDataType
,
remove_reference_t
<
DeviceGemmXdl
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdl
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmXdl
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
remove_reference_t
<
Devic
eGemm
Xdl
::
Block2CTileMap
>
,
remove_reference_t
<
typename
Gridwis
eGemm
::
Block2CTileMap
>
,
false
,
true
>
;
...
...
@@ -436,11 +431,12 @@ struct DeviceGemmXdl
CDataType
,
remove_reference_t
<
DeviceGemmXdl
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdl
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmXdl
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
remove_reference_t
<
Devic
eGemm
Xdl
::
Block2CTileMap
>
,
remove_reference_t
<
typename
Gridwis
eGemm
::
Block2CTileMap
>
,
false
,
false
>
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment