Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
20718690
Commit
20718690
authored
Apr 14, 2022
by
Chao Liu
Browse files
adding gemm pipeline
parent
18707866
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
172 additions
and
155 deletions
+172
-155
example/01_gemm/gemm_xdl_fp16.cpp
example/01_gemm/gemm_xdl_fp16.cpp
+34
-9
include/ck/config.hpp
include/ck/config.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
+39
-40
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp
+5
-12
profiler/CMakeLists.txt
profiler/CMakeLists.txt
+29
-29
profiler/src/profiler.cpp
profiler/src/profiler.cpp
+64
-64
No files found.
example/01_gemm/gemm_xdl_fp16.cpp
View file @
20718690
...
@@ -11,9 +11,10 @@
...
@@ -11,9 +11,10 @@
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
//#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
//#include "device_gemm_xdl_c_shuffle.hpp"
#include "device_gemm_xdl_cshuffle.hpp"
//#include "device_gemm_xdl_cshuffle.hpp"
#include "device_gemm_xdl_cshuffle_v2.hpp"
#include "element_wise_operation.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
#include "gemm_specialization.hpp"
...
@@ -42,15 +43,39 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough;
...
@@ -42,15 +43,39 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmMNPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
;
// clang-format off
// clang-format off
#if 0
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| DataType| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
//< Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// // 1-stage prefetch
< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// // 2-stage prefetch
// < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
#elif
1
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemm_Xdl_CShuffle_v2
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
64
,
1
,
8
>
,
8
>
;
// < Row, Col, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 8>, 8>;
#elif 1
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdl
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
;
// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>;
// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>;
#endif
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
...
...
include/ck/config.hpp
View file @
20718690
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
#define CK_USE_LAUNCH_BOUNDS 1
#define CK_USE_LAUNCH_BOUNDS 1
#ifdef CK_USE_LAUNCH_BOUNDS
#ifdef CK_USE_LAUNCH_BOUNDS
#define CK_MAX_THREAD_PER_BLOCK
256
#define CK_MAX_THREAD_PER_BLOCK
512
#define CK_MIN_BLOCK_PER_CU 1
#define CK_MIN_BLOCK_PER_CU 1
#endif
#endif
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
View file @
20718690
...
@@ -71,35 +71,35 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
...
@@ -71,35 +71,35 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
static
__device__
void
RunABBlockTransferPipeline
(
const
AGridDesc
&
a_grid_desc
,
static
__device__
void
RunABBlockTransferPipeline
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_block
wise
_copy
,
ABlockTransfer
&
a_block_copy
,
const
AGridBuffer
&
a_grid_buf
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_block
wise
_copy
,
BBlockTransfer
&
b_block_copy
,
const
BGridBuffer
&
b_grid_buf
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
const
BBlockTransferStep
&
b_block_copy_step
,
index_t
num_loop
)
index_t
num_loop
)
{
{
// global read 0
// global read 0
a_block
wise
_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
a_block_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_block
wise
_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
b_block_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
// move to 1
// move to 1
a_block
wise
_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
a_block_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_block
wise
_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
b_block_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// LDS write 0
// LDS write 0
a_block
wise
_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
a_block_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
// global Read 1
// global Read 1
a_block
wise
_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
a_block_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
// LDS write 0
// LDS write 0
b_block
wise
_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
b_block_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
// global Read 1
// global Read 1
b_block
wise
_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
b_block_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
// main body
// main body
// FIXME: HasMainLoop = (num_loop) > 2
// FIXME: HasMainLoop = (num_loop) > 2
...
@@ -116,18 +116,18 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
...
@@ -116,18 +116,18 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
block_sync_lds
();
block_sync_lds
();
// move to i + 2
// move to i + 2
a_block
wise
_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
a_block_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_block
wise
_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
b_block_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// LDS write i + 1
// LDS write i + 1
a_block
wise
_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
a_block_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
// global read i + 2
// global read i + 2
a_block
wise
_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
a_block_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
// LDS write i + 1
// LDS write i + 1
b_block
wise
_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
b_block_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
// global read i + 2
// global read i + 2
b_block
wise
_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
b_block_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
++
i
;
++
i
;
}
while
(
i
<
(
num_loop
-
2
));
}
while
(
i
<
(
num_loop
-
2
));
...
@@ -142,8 +142,8 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
...
@@ -142,8 +142,8 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
block_sync_lds
();
block_sync_lds
();
// LDS write num_loop - 1
// LDS write num_loop - 1
a_block
wise
_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
a_block_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_block
wise
_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
b_block_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
block_sync_lds
();
block_sync_lds
();
...
@@ -153,7 +153,7 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
...
@@ -153,7 +153,7 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
static
__device__
void
RunBlockGemmPipeline
(
ABlockBuffer
&
a_block_buf
,
static
__device__
void
RunBlockGemmPipeline
(
ABlockBuffer
&
a_block_buf
,
BBlockBuffer
&
b_block_buf
,
BBlockBuffer
&
b_block_buf
,
const
BlockwiseGemm
&
block
wise
_gemm
,
const
BlockwiseGemm
&
block_gemm
,
CThreadBuffer
&
c_thread_buf
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
index_t
num_loop
)
{
{
...
@@ -171,7 +171,7 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
...
@@ -171,7 +171,7 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
block_sync_lds
();
block_sync_lds
();
// GEMM i
// GEMM i
block
wise
_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
block_sync_lds
();
...
@@ -192,7 +192,7 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
...
@@ -192,7 +192,7 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
block_sync_lds
();
block_sync_lds
();
// GEMM num_loop - 2
// GEMM num_loop - 2
block
wise
_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
block_sync_lds
();
...
@@ -201,46 +201,45 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
...
@@ -201,46 +201,45 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
block_sync_lds
();
block_sync_lds
();
// GEMM num_loop - 1
// GEMM num_loop - 1
block
wise
_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
}
}
}
}
static
__device__
void
Run
(
const
AGridDesc
&
a_grid_desc
,
static
__device__
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_block
wise
_copy
,
ABlockTransfer
&
a_block_copy
,
const
AGridBuffer
&
a_grid_buf
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_block
wise
_copy
,
BBlockTransfer
&
b_block_copy
,
const
BGridBuffer
&
b_grid_buf
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
const
BBlockTransferStep
&
b_block_copy_step
,
const
BlockwiseGemm
&
block
wise
_gemm
,
const
BlockwiseGemm
&
block_gemm
,
CThreadBuffer
&
c_thread_buf
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
index_t
num_loop
)
{
{
if
(
ABBlockTransferThreadGroup
::
IsBelong
())
if
(
ABBlockTransferThreadGroup
::
IsBelong
())
{
{
gridwise_gemm_pipeline
.
RunABBlockTransferPipeline
(
a_grid_desc
_ak0_m_ak1
,
RunABBlockTransferPipeline
(
a_grid_desc
,
a_block_desc
_ak0_m_ak1
,
a_block_desc
,
a_block
wise
_copy
,
a_block_copy
,
a_grid_buf
,
a_grid_buf
,
a_block_buf
,
a_block_buf
,
a_block_
slice_
copy_step
,
a_block_copy_step
,
b_grid_desc
_bk0_n_bk1
,
b_grid_desc
,
b_block_desc
_bk0_n_bk1
,
b_block_desc
,
b_block
wise
_copy
,
b_block_copy
,
b_grid_buf
,
b_grid_buf
,
b_block_buf
,
b_block_buf
,
b_block_
slice_
copy_step
,
b_block_copy_step
,
num_loop
);
num_loop
);
}
}
else
if
(
BlockGemmThreadGroup
::
IsBelong
())
else
if
(
BlockGemmThreadGroup
::
IsBelong
())
{
{
gridwise_gemm_pipeline
.
RunBlockGemmPipeline
(
RunBlockGemmPipeline
(
a_block_buf
,
b_block_buf
,
block_gemm
,
c_thread_buf
,
num_loop
);
a_block_buf
,
b_block_buf
,
blockwise_gemm
,
c_thread_buf
,
num_loop
);
}
}
}
}
};
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp
View file @
20718690
...
@@ -4,12 +4,11 @@
...
@@ -4,12 +4,11 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "blockwise_tensor_slice_transfer_v6r1.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
#include "gridwise_gemm_pipeline_v2.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -118,11 +117,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
...
@@ -118,11 +117,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
using
ThisThreadBlock
=
using
ThisThreadBlock
=
AnyThreadBlock
<
ABBlockTransferThreadGroupSize
+
BlockGemmThreadGroupSize
>
;
AnyThreadBlock
<
ABBlockTransferThreadGroupSize
+
BlockGemmThreadGroupSize
>
;
#if 1
using
ABBlockTransferThreadGroup
=
ThisThreadBlock
;
using
BlockGemmThreadGroup
=
ThisThreadBlock
;
using
CShuffleBlockTransferThreadGroup
=
ThisThreadBlock
;
#else
struct
ABBlockTransferThreadGroup
struct
ABBlockTransferThreadGroup
{
{
__device__
static
constexpr
index_t
GetNumOfThread
()
__device__
static
constexpr
index_t
GetNumOfThread
()
...
@@ -157,7 +151,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
...
@@ -157,7 +151,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
};
};
using
CShuffleBlockTransferThreadGroup
=
ThisThreadBlock
;
using
CShuffleBlockTransferThreadGroup
=
ThisThreadBlock
;
#endif
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()
{
{
...
@@ -494,7 +487,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
...
@@ -494,7 +487,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
))
/
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
))
/
KPerBlock
);
KPerBlock
);
#if
1
#if
0
// gridwise GEMM pipeline
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
...
@@ -667,9 +660,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
...
@@ -667,9 +660,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
// shuffle: blockwise copy C from LDS to global
// shuffle: blockwise copy C from LDS to global
auto
c_shuffle_block_copy_lds_to_global
=
ThreadGroupTensorSliceTransfer_v6r1
<
auto
c_shuffle_block_copy_lds_to_global
=
ThreadGroupTensorSliceTransfer_v6r1
<
ThisThreadBlock
,
// ThreadGroup
CShuffleBlockTransferThreadGroup
,
// ThreadGroup
CElementwiseOperation
,
// ElementwiseOperation,
CElementwiseOperation
,
// ElementwiseOperation,
CGlobalMemoryDataOperation
,
// DstInMemOp,
CGlobalMemoryDataOperation
,
// DstInMemOp,
Sequence
<
1
,
Sequence
<
1
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
1
,
...
...
profiler/CMakeLists.txt
View file @
20718690
...
@@ -24,38 +24,38 @@ include_directories(BEFORE
...
@@ -24,38 +24,38 @@ include_directories(BEFORE
set
(
PROFILER_SOURCE
set
(
PROFILER_SOURCE
src/profiler.cpp
src/profiler.cpp
src/profile_gemm.cpp
src/profile_gemm.cpp
src/profile_gemm_bias_2d.cpp
#
src/profile_gemm_bias_2d.cpp
src/profile_gemm_bias_relu.cpp
#
src/profile_gemm_bias_relu.cpp
src/profile_gemm_bias_relu_add.cpp
#
src/profile_gemm_bias_relu_add.cpp
src/profile_gemm_reduce.cpp
#
src/profile_gemm_reduce.cpp
src/profile_batched_gemm.cpp
#
src/profile_batched_gemm.cpp
src/profile_conv_fwd.cpp
#
src/profile_conv_fwd.cpp
src/profile_conv_fwd_bias_relu.cpp
#
src/profile_conv_fwd_bias_relu.cpp
src/profile_conv_fwd_bias_relu_add.cpp
#
src/profile_conv_fwd_bias_relu_add.cpp
src/profile_conv_fwd_bias_relu_atomic_add.cpp
#
src/profile_conv_fwd_bias_relu_atomic_add.cpp
src/profile_convnd_bwd_data.cpp
#
src/profile_convnd_bwd_data.cpp
src/profile_reduce.cpp
#
src/profile_reduce.cpp
src/profile_grouped_gemm.cpp
#
src/profile_grouped_gemm.cpp
src/profile_conv_bwd_weight.cpp
#
src/profile_conv_bwd_weight.cpp
src/profile_batched_gemm_reduce.cpp
#
src/profile_batched_gemm_reduce.cpp
)
)
add_executable
(
ckProfiler
${
PROFILER_SOURCE
}
)
add_executable
(
ckProfiler
${
PROFILER_SOURCE
}
)
target_link_libraries
(
ckProfiler PRIVATE host_tensor
)
target_link_libraries
(
ckProfiler PRIVATE host_tensor
)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_reduce_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_bias2d_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_bias_relu_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_bias_relu_add_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance)
target_link_libraries
(
ckProfiler PRIVATE device_batched_gemm_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
target_link_libraries
(
ckProfiler PRIVATE device_conv2d_fwd_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
target_link_libraries
(
ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
target_link_libraries
(
ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
target_link_libraries
(
ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance)
target_link_libraries
(
ckProfiler PRIVATE device_convnd_bwd_data_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance)
target_link_libraries
(
ckProfiler PRIVATE device_reduce_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
target_link_libraries
(
ckProfiler PRIVATE device_reduce_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
target_link_libraries
(
ckProfiler PRIVATE device_grouped_gemm_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
target_link_libraries
(
ckProfiler PRIVATE device_conv2d_bwd_weight_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance)
target_link_libraries
(
ckProfiler PRIVATE device_batched_gemm_reduce_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
profiler/src/profiler.cpp
View file @
20718690
...
@@ -26,70 +26,70 @@ int main(int argc, char* argv[])
...
@@ -26,70 +26,70 @@ int main(int argc, char* argv[])
{
{
return
profile_gemm
(
argc
,
argv
);
return
profile_gemm
(
argc
,
argv
);
}
}
else
if
(
strcmp
(
argv
[
1
],
"gemm_bias_2d"
)
==
0
)
//
else if(strcmp(argv[1], "gemm_bias_2d") == 0)
{
//
{
return
profile_gemm_bias_2d
(
argc
,
argv
);
//
return profile_gemm_bias_2d(argc, argv);
}
//
}
else
if
(
strcmp
(
argv
[
1
],
"gemm_bias_relu"
)
==
0
)
//
else if(strcmp(argv[1], "gemm_bias_relu") == 0)
{
//
{
return
profile_gemm_bias_relu
(
argc
,
argv
);
//
return profile_gemm_bias_relu(argc, argv);
}
//
}
else
if
(
strcmp
(
argv
[
1
],
"gemm_bias_relu_add"
)
==
0
)
//
else if(strcmp(argv[1], "gemm_bias_relu_add") == 0)
{
//
{
return
profile_gemm_bias_relu_add
(
argc
,
argv
);
//
return profile_gemm_bias_relu_add(argc, argv);
}
//
}
else
if
(
strcmp
(
argv
[
1
],
"gemm_reduce"
)
==
0
)
//
else if(strcmp(argv[1], "gemm_reduce") == 0)
{
//
{
return
profile_gemm_reduce
(
argc
,
argv
);
//
return profile_gemm_reduce(argc, argv);
}
//
}
else
if
(
strcmp
(
argv
[
1
],
"batched_gemm"
)
==
0
)
//
else if(strcmp(argv[1], "batched_gemm") == 0)
{
//
{
return
profile_batched_gemm
(
argc
,
argv
);
//
return profile_batched_gemm(argc, argv);
}
//
}
else
if
(
strcmp
(
argv
[
1
],
"batched_gemm_reduce"
)
==
0
)
//
else if(strcmp(argv[1], "batched_gemm_reduce") == 0)
{
//
{
return
profile_batched_gemm_reduce
(
argc
,
argv
);
//
return profile_batched_gemm_reduce(argc, argv);
}
//
}
else
if
(
strcmp
(
argv
[
1
],
"grouped_gemm"
)
==
0
)
//
else if(strcmp(argv[1], "grouped_gemm") == 0)
{
//
{
profile_grouped_gemm
(
argc
,
argv
);
//
profile_grouped_gemm(argc, argv);
}
//
}
else
if
(
strcmp
(
argv
[
1
],
"conv_fwd"
)
==
0
)
//
else if(strcmp(argv[1], "conv_fwd") == 0)
{
//
{
return
profile_conv_fwd
(
argc
,
argv
);
//
return profile_conv_fwd(argc, argv);
}
//
}
else
if
(
strcmp
(
argv
[
1
],
"conv_fwd_bias_relu"
)
==
0
)
//
else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0)
{
//
{
return
profile_conv_fwd_bias_relu
(
argc
,
argv
);
//
return profile_conv_fwd_bias_relu(argc, argv);
}
//
}
else
if
(
strcmp
(
argv
[
1
],
"conv_fwd_bias_relu_add"
)
==
0
)
//
else if(strcmp(argv[1], "conv_fwd_bias_relu_add") == 0)
{
//
{
return
profile_conv_fwd_bias_relu_add
(
argc
,
argv
);
//
return profile_conv_fwd_bias_relu_add(argc, argv);
}
//
}
else
if
(
strcmp
(
argv
[
1
],
"conv_fwd_bias_relu_atomic_add"
)
==
0
)
//
else if(strcmp(argv[1], "conv_fwd_bias_relu_atomic_add") == 0)
{
//
{
return
profile_conv_fwd_bias_relu_atomic_add
(
argc
,
argv
);
//
return profile_conv_fwd_bias_relu_atomic_add(argc, argv);
}
//
}
else
if
(
strcmp
(
argv
[
1
],
"conv1d_bwd_data"
)
==
0
)
//
else if(strcmp(argv[1], "conv1d_bwd_data") == 0)
{
//
{
return
profile_convnd_bwd_data
(
argc
,
argv
,
1
);
//
return profile_convnd_bwd_data(argc, argv, 1);
}
//
}
else
if
(
strcmp
(
argv
[
1
],
"conv2d_bwd_data"
)
==
0
)
//
else if(strcmp(argv[1], "conv2d_bwd_data") == 0)
{
//
{
return
profile_convnd_bwd_data
(
argc
,
argv
,
2
);
//
return profile_convnd_bwd_data(argc, argv, 2);
}
//
}
else
if
(
strcmp
(
argv
[
1
],
"conv3d_bwd_data"
)
==
0
)
//
else if(strcmp(argv[1], "conv3d_bwd_data") == 0)
{
//
{
return
profile_convnd_bwd_data
(
argc
,
argv
,
3
);
//
return profile_convnd_bwd_data(argc, argv, 3);
}
//
}
else
if
(
strcmp
(
argv
[
1
],
"reduce"
)
==
0
)
//
else if(strcmp(argv[1], "reduce") == 0)
{
//
{
return
profile_reduce
(
argc
,
argv
);
//
return profile_reduce(argc, argv);
}
//
}
else
if
(
strcmp
(
argv
[
1
],
"conv2d_bwd_weight"
)
==
0
)
//
else if(strcmp(argv[1], "conv2d_bwd_weight") == 0)
{
//
{
return
profile_conv_bwd_weight
(
argc
,
argv
);
//
return profile_conv_bwd_weight(argc, argv);
}
//
}
else
else
{
{
// clang-format off
// clang-format off
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment