Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d2640676
"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "212bd0a609e27ba03a789dd925ae1dfc47be4873"
Commit
d2640676
authored
Jul 10, 2022
by
Wenkai
Browse files
add optimization for small gemm
parent
840a617d
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
715 additions
and
28 deletions
+715
-28
example/01_gemm/gemm_xdl_fp16_splitk.cpp
example/01_gemm/gemm_xdl_fp16_splitk.cpp
+15
-6
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle_small_gemm.hpp
...pu/device/device_gemm_xdl_splitk_c_shuffle_small_gemm.hpp
+665
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
+2
-2
include/ck/utility/common_header.hpp
include/ck/utility/common_header.hpp
+33
-20
No files found.
example/01_gemm/gemm_xdl_fp16_splitk.cpp
View file @
d2640676
...
...
@@ -11,7 +11,7 @@
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_splitk_c_shuffle.hpp"
#include "device_gemm_xdl_splitk_c_shuffle
_small_gemm
.hpp"
#include "device_gemm_xdl_splitk_c_shuffle_static.hpp"
#include "device_gemm_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
...
...
@@ -47,25 +47,34 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
#if USEING_STATIC_KERNEL
#if MNKB_1_4
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdlSplitKCShuffleStatic
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//<Row, Row, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault, 3, 256, 16, 128, 32, 8, 2, 16, 16, 1, 2, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 8, 1, 2, S<1, 4, 1, 64>, 2>;
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
2
,
256
,
16
,
128
,
32
,
8
,
2
,
16
,
16
,
1
,
2
,
S
<
1
,
4
,
16
,
4
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
2
,
2
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
2
,
8
,
1
,
2
,
S
<
1
,
4
,
1
,
64
>
,
2
>
;
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
1
,
256
,
16
,
128
,
32
,
8
,
2
,
16
,
16
,
1
,
2
,
S
<
1
,
4
,
16
,
4
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
2
,
2
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
2
,
8
,
1
,
2
,
S
<
1
,
4
,
1
,
64
>
,
2
>
;
#else
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdlSplitKCShuffleStatic
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
3
,
256
,
16
,
128
,
32
,
8
,
2
,
16
,
16
,
1
,
2
,
S
<
1
,
4
,
16
,
4
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
2
,
2
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
2
,
8
,
1
,
2
,
S
<
1
,
4
,
1
,
64
>
,
2
>
;
//<Row, Row, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 256, 16, 256, 32, 8, 2, 16, 16, 1, 4, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 8, 1, 2, S<1, 4, 1, 64>, 2>;
//<Row, Col, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 16, 128, 128, 8, 8, 16, 16, 1, 2, S<1, 16, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 2, S<1, 4, 1, 64>, 2>;
//<Row, Row, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault, 4, 256, 16, 128, 32, 8, 2, 16, 16, 1, 2, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 8, 1, 2, S<1, 4, 1, 64>, 2>;
#endif
#else
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdlSplitKCShuffle
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdlSplitKCShuffle
SmallGemm
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
3
,
256
,
16
,
128
,
32
,
8
,
2
,
16
,
16
,
1
,
2
,
S
<
1
,
4
,
16
,
4
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
2
,
2
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
2
,
8
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
2
>
;
//<Row, Row, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault,
4
, 256, 16,
128
, 32, 8, 2, 16, 16, 1,
2
, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, 1, S<1,
8
,
32
, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 8, 1,
1
, S<1,
16
, 1,
1
6>, 2>;
//<Row, Row, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault,
3
, 256, 16,
256
, 32, 8, 2, 16, 16, 1,
4
, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, 1, S<1,
4
,
64
, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 8, 1,
2
, S<1,
4
, 1, 6
4
>, 2>;
// clang-format on
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle_small_gemm.hpp
0 → 100644
View file @
d2640676
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
View file @
d2640676
...
...
@@ -14,12 +14,12 @@ struct GridwiseGemmPipeline_v2<1>
__host__
__device__
static
constexpr
bool
IsSupported
(
index_t
num_loop
)
{
// TODO: improve applicability
return
num_loop
>
2
;
return
num_loop
>
=
2
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainLoop
(
index_t
num_loop
)
{
return
(
num_loop
/
2
)
>
1
;
return
num_loop
>
2
;
}
template
<
bool
HasMainLoop
,
...
...
include/ck/utility/common_header.hpp
View file @
d2640676
...
...
@@ -48,31 +48,44 @@
#include "amd_xdlops.hpp"
#endif
#define USEING_STATIC_KERNEL
1
#define USEING_STATIC_KERNEL
0
#define MNKB_
16_1152_512
0_8
0
#define MNKB_1
6_5120_384_3 1
#define MNKB_
16_1280_5120
_8
1
#define MNKB_
16_5120_1280_5 1
#define MNKB_0_8
1
#define MNKB_1
_4 0
#define MNKB_
2
_8
0
#define MNKB_
3_5 0
#if MNKB_16_1152_5120_8
#define MNKB_4_5 0
#define MNKB_5_5 0
#if MNKB_0_8
#define M_matrix 16
#define N_matrix 4096
#define K_matrix 12800
#define K_batch 5
#elif MNKB_1_4
#define M_matrix 16
#define N_matrix 4096
#define K_matrix 12800
#define K_batch 5
#elif MNKB_2_8
#define M_matrix 16
#define N_matrix
1152
#define K_matrix
5
120
#define K_batch
8
#elif MNKB_
16_5120_384_3
#define N_matrix
4096
#define K_matrix 12
80
0
#define K_batch
5
#elif MNKB_
3_5
#define M_matrix 16
#define N_matrix
5120
#define K_matrix
384
#define K_batch
4
#elif MNKB_
16_1280_5120_8
#define N_matrix
4096
#define K_matrix
12800
#define K_batch
5
#elif MNKB_
4_5
#define M_matrix 16
#define N_matrix
1280
#define K_matrix
5
120
#define K_batch
8
#elif MNKB_
16_5120_1280
_5
#define N_matrix
4096
#define K_matrix 12
80
0
#define K_batch
5
#elif MNKB_
5
_5
#define M_matrix 16
#define N_matrix
5120
#define K_matrix 1280
#define N_matrix
4096
#define K_matrix 1280
0
#define K_batch 5
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment