Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
303b1a86
Commit
303b1a86
authored
Dec 27, 2021
by
ltqin
Browse files
remove step hack
parent
aaa89914
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
81 deletions
+11
-81
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
+5
-10
device_operation/include/device_gemm_splitk_xdl.hpp
device_operation/include/device_gemm_splitk_xdl.hpp
+6
-71
No files found.
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
View file @
303b1a86
...
...
@@ -136,7 +136,6 @@ template <index_t BlockSize,
index_t
K1Value
,
index_t
MRepeat
,
index_t
NRepeat
,
typename
ABlockTransferThreadSliceLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
...
...
@@ -144,7 +143,7 @@ template <index_t BlockSize,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
typename
BBlockTransferThreadSliceLengths_K0_N_K1
,
bool
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
...
...
@@ -152,12 +151,10 @@ template <index_t BlockSize,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
bool
BBlockLdsExtraN
,
typename
CThreadTransferSrcDstAccessOrder
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferDstScalarPerVector
,
bool
CAccessOrderMRepeatNRepeat
,
bool
ABlockLdsExtraM
,
bool
BBlockLdsExtraN
>
index_t
CThreadTransferDstScalarPerVector
>
struct
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -477,7 +474,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum_t
::
Set
,
Sequence
<
1
,
K0PerBlock
,
MPerBlock
,
K1
>
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
FloatAB
,
...
...
@@ -508,7 +504,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum_t
::
Set
,
Sequence
<
1
,
K0PerBlock
,
NPerBlock
,
K1
>
,
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
FloatAB
,
...
...
@@ -604,8 +599,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
a_blockwise_copy
.
RunWrite
(
a_b_k0_m_k1_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_b_k0_n_k1_block_desc
,
b_block_buf
);
k_block_data_begin
+=
K0PerBlock
;
}
while
(
k_block_data_begin
<
(
K0
-
K0PerBlock
));
k
0
_block_data_begin
+=
K0PerBlock
;
}
while
(
k
0
_block_data_begin
<
(
K0
-
K0PerBlock
));
}
// tail
...
...
device_operation/include/device_gemm_splitk_xdl.hpp
View file @
303b1a86
...
...
@@ -159,51 +159,6 @@ struct DeviceGemmSplitKXdl
using
BGridDesc_K0_N_K1
=
decltype
(
MakeBGridDescriptor_KBatch_K0_N_K1
(
1
,
1
,
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
// TODO remove these hacks
static
constexpr
auto
a_kbatch_k0_m_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 0+: Kbatch
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 1+: K0
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 2+: M
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
// 3+: K1
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 0-: Kbatch
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 1-: K0
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 2-: M
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
// 3-: K1
static
constexpr
auto
b_kbatch_k0_n_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 0+: Kbatch
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 0+: K0
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 1+: N
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 0-: Kbatch
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 1-: K0
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 2-: N
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
// 3-: K1
static
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3+: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4+: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5+: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6+: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 7+: N2
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
static
constexpr
auto
a_kbatch_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
{};
static
constexpr
auto
b_kbatch_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
{};
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
<
BlockSize
,
...
...
@@ -225,7 +180,6 @@ struct DeviceGemmSplitKXdl
K1
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
...
...
@@ -233,7 +187,7 @@ struct DeviceGemmSplitKXdl
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// AThreadTransferSrcResetCoordinateAfterRun,
B
Block
TransferThreadSliceLengths_K0_N_K1
,
A
Block
LdsAddExtraM
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
...
...
@@ -241,19 +195,10 @@ struct DeviceGemmSplitKXdl
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN
,
Sequence
<
0
,
2
,
4
,
5
,
6
,
1
,
3
,
7
>
,
// CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
decltype
(
a_kbatch_k0_m_k1_grid_step_hacks
),
// AGridStepHacks,
decltype
(
b_kbatch_k0_n_k1_grid_step_hacks
),
// BGridStepHacks,
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
// CGridStepHacks,
decltype
(
a_kbatch_k0_m_k1_grid_move_slice_window_step_hacks
),
// AGridMoveSliceWindowStepHacks,
decltype
(
b_kbatch_k0_n_k1_grid_move_slice_window_step_hacks
),
// BGridMoveSliceWindowStepHacks,
false
,
// CAccessOrderMRepeatNRepeat,
ABlockLdsAddExtraM
,
BBlockLdsAddExtraN
>
;
CThreadTransferDstScalarPerVector
>
;
// GridwiseGemm
using
GridwiseGemmAtomicAdd
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
<
...
...
@@ -276,7 +221,6 @@ struct DeviceGemmSplitKXdl
K1
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
...
...
@@ -284,7 +228,7 @@ struct DeviceGemmSplitKXdl
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// AThreadTransferSrcResetCoordinateAfterRun,
B
Block
TransferThreadSliceLengths_K0_N_K1
,
A
Block
LdsAddExtraM
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
...
...
@@ -292,19 +236,10 @@ struct DeviceGemmSplitKXdl
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN
,
Sequence
<
0
,
2
,
4
,
5
,
6
,
1
,
3
,
7
>
,
// CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
decltype
(
a_kbatch_k0_m_k1_grid_step_hacks
),
// AGridStepHacks,
decltype
(
b_kbatch_k0_n_k1_grid_step_hacks
),
// BGridStepHacks,
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
// CGridStepHacks,
decltype
(
a_kbatch_k0_m_k1_grid_move_slice_window_step_hacks
),
// AGridMoveSliceWindowStepHacks,
decltype
(
b_kbatch_k0_n_k1_grid_move_slice_window_step_hacks
),
// BGridMoveSliceWindowStepHacks,
false
,
// CAccessOrderMRepeatNRepeat,
ABlockLdsAddExtraM
,
BBlockLdsAddExtraN
>
;
CThreadTransferDstScalarPerVector
>
;
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
GridwiseGemm
::
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
CGridDesc_M_N
{}));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment