Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5b3bd032
Commit
5b3bd032
authored
Apr 22, 2022
by
Chao Liu
Browse files
refactor
parent
76ee0baf
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
340 additions
and
302 deletions
+340
-302
example/01_gemm/gemm_xdl_fp16.cpp
example/01_gemm/gemm_xdl_fp16.cpp
+5
-15
include/ck/tensor_operation/gpu/device/device_gemm_xdl_producer_consumer_cshuffle.hpp
...gpu/device/device_gemm_xdl_producer_consumer_cshuffle.hpp
+8
-8
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_producer_consumer.hpp
...ion/gpu/grid/gridwise_gemm_pipeline_producer_consumer.hpp
+234
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
+0
-75
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
+39
-153
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
+5
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_producer_consumer_cshuffle.hpp
...gpu/grid/gridwise_gemm_xdl_producer_consumer_cshuffle.hpp
+32
-33
library/src/utility/conv_fwd_util.cpp
library/src/utility/conv_fwd_util.cpp
+17
-18
No files found.
example/01_gemm/gemm_xdl_fp16.cpp
View file @
5b3bd032
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
#include "device_tensor.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_cshuffle.hpp"
#include "device_gemm_xdl_cshuffle.hpp"
#include "device_gemm_xdl_cshuffle
_v2
.hpp"
#include "device_gemm_xdl_
producer_consumer_
cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
#include "gemm_specialization.hpp"
...
@@ -57,24 +57,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
...
@@ -57,24 +57,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
// < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// // 2-stage prefetch
// // 2-stage prefetch
// < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
#elif 0
#elif 1
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemm_Xdl_CShuffle_v2
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemm_Xdl_ProducerConsumer_CShuffle
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// all thread
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
1
,
256
,
0
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
// < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 0, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
#elif 0
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemm_Xdl_CShuffle_v2
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// producer & consumer
// < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 8>, 8>;
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
64
,
1
,
8
>
,
8
>
;
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
1
,
256
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
64
,
1
,
8
>
,
8
>
;
// < Row, Col, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 8>, 8>;
#elif 1
#elif 1
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdl
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdl
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle
_v2
.hpp
→
include/ck/tensor_operation/gpu/device/device_gemm_xdl_
producer_consumer_
cshuffle.hpp
View file @
5b3bd032
...
@@ -7,8 +7,8 @@
...
@@ -7,8 +7,8 @@
#include "tensor_layout.hpp"
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdl_cshuffle
_v2
.hpp"
#include "gridwise_gemm_xdl_
producer_consumer_
cshuffle.hpp"
#include "
tensor_operation/gpu/device/
gemm_specialization.hpp"
#include "gemm_specialization.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -56,10 +56,10 @@ template <typename ALayout,
...
@@ -56,10 +56,10 @@ template <typename ALayout,
index_t
CShuffleNXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
>
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
>
struct
DeviceGemm_Xdl_CShuffle
_v2
struct
DeviceGemm_Xdl_
ProducerConsumer_
CShuffle
:
public
DeviceGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
:
public
DeviceGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
{
{
using
DeviceOp
=
DeviceGemm_Xdl_CShuffle
_v2
;
using
DeviceOp
=
DeviceGemm_Xdl_
ProducerConsumer_
CShuffle
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -334,7 +334,7 @@ struct DeviceGemm_Xdl_CShuffle_v2
...
@@ -334,7 +334,7 @@ struct DeviceGemm_Xdl_CShuffle_v2
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle
_v2
<
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdl_
producer_consumer_
cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
GemmAccDataType
,
GemmAccDataType
,
CShuffleDataType
,
CShuffleDataType
,
...
@@ -471,7 +471,7 @@ struct DeviceGemm_Xdl_CShuffle_v2
...
@@ -471,7 +471,7 @@ struct DeviceGemm_Xdl_CShuffle_v2
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle
_v2
<
const
auto
kernel
=
kernel_gemm_xdl_
producer_consumer_
cshuffle
<
GridwiseGemm
,
GridwiseGemm
,
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
CDataType
,
CDataType
,
...
@@ -523,7 +523,7 @@ struct DeviceGemm_Xdl_CShuffle_v2
...
@@ -523,7 +523,7 @@ struct DeviceGemm_Xdl_CShuffle_v2
}
}
else
else
{
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle
_v2
<
const
auto
kernel
=
kernel_gemm_xdl_
producer_consumer_
cshuffle
<
GridwiseGemm
,
GridwiseGemm
,
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
CDataType
,
CDataType
,
...
@@ -672,7 +672,7 @@ struct DeviceGemm_Xdl_CShuffle_v2
...
@@ -672,7 +672,7 @@ struct DeviceGemm_Xdl_CShuffle_v2
auto
str
=
std
::
stringstream
();
auto
str
=
std
::
stringstream
();
// clang-format off
// clang-format off
str
<<
"DeviceGemm_Xdl_CShuffle
_v2
"
str
<<
"DeviceGemm_Xdl_
ProducerConsumer_
CShuffle"
<<
"<"
<<
"<"
<<
ABBlockTransferThreadGroupSize
<<
", "
<<
ABBlockTransferThreadGroupSize
<<
", "
<<
BlockGemmThreadGroupSize
<<
", "
<<
BlockGemmThreadGroupSize
<<
", "
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_producer_consumer.hpp
0 → 100644
View file @
5b3bd032
#pragma once
#include "common_header.hpp"
namespace
ck
{
// Primary template of the producer/consumer GEMM pipeline.
// It is only declared here (no definition); behaviour is supplied by partial
// specializations on NumGemmKPrefetchStage.
//   ABBlockTransferThreadGroup - thread-group type queried via IsBelong() to
//                                pick the threads that run the A/B transfer pipeline
//   BlockGemmThreadGroup       - thread-group type queried via IsBelong() to
//                                pick the threads that run the block-GEMM pipeline
//   NumGemmKPrefetchStage      - number of prefetch stages; selects the specialization
template
<
typename
ABBlockTransferThreadGroup
,
typename
BlockGemmThreadGroup
,
index_t
NumGemmKPrefetchStage
>
struct
GridwiseGemmPipelineProducerConsumer
;
// 1-stage prefetch
//
// The work is split between two thread groups:
//   - threads for which ABBlockTransferThreadGroup::IsBelong() is true run
//     RunABBlockTransferPipeline (global read -> LDS write of the A/B tiles);
//   - otherwise, threads for which BlockGemmThreadGroup::IsBelong() is true run
//     RunBlockGemmPipeline (block GEMM on the A/B block buffers).
// Both pipelines call block_sync_lds() the same number of times and at the same
// points of the schedule; keep them in lock-step when editing either one.
// NOTE(review): block_sync_lds() is presumably a block-wide barrier - confirm
// against its definition before changing the number or position of the calls.
template <typename ABBlockTransferThreadGroup, typename BlockGemmThreadGroup>
struct GridwiseGemmPipelineProducerConsumer<ABBlockTransferThreadGroup, BlockGemmThreadGroup, 1>
{
    // Returns true when this pipeline accepts `num_loop` iterations:
    // currently only even iteration counts are accepted.
    __host__ __device__ static constexpr bool IsSupported(index_t num_loop)
    {
        // TODO: improve applicability
        const index_t remainder = num_loop % 2;

        return remainder == 0;
    }

    // Returns the value to use for the HasMainLoop template argument of the
    // Run* functions: true when half of `num_loop` (integer division) exceeds 1.
    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        const index_t half_num_loop = num_loop / 2;

        return half_num_loop > 1;
    }

    // Transfer side of the pipeline: reads A/B tiles through the grid
    // descriptors/buffers and writes them into the block buffers.
    // The "GEMM ..." comments mark the points where the GEMM side
    // (RunBlockGemmPipeline) performs its work between the same barriers.
    //
    //   HasMainLoop            - compile-time switch for the steady-state loop
    //   a_/b_block_copy        - transfer objects (RunRead / RunWrite / MoveSrcSliceWindow)
    //   a_/b_block_copy_step   - step passed to MoveSrcSliceWindow between tiles
    //   num_loop               - total number of iterations
    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep>
    static __device__ void RunABBlockTransferPipeline(const AGridDesc& a_grid_desc,
                                                      const ABlockDesc& a_block_desc,
                                                      ABlockTransfer& a_block_copy,
                                                      const AGridBuffer& a_grid_buf,
                                                      ABlockBuffer& a_block_buf,
                                                      const ABlockTransferStep& a_block_copy_step,
                                                      const BGridDesc& b_grid_desc,
                                                      const BBlockDesc& b_block_desc,
                                                      BBlockTransfer& b_block_copy,
                                                      const BGridBuffer& b_grid_buf,
                                                      BBlockBuffer& b_block_buf,
                                                      const BBlockTransferStep& b_block_copy_step,
                                                      index_t num_loop)
    {
        // prologue: global read of tile 0
        a_block_copy.RunRead(a_grid_desc, a_grid_buf);
        b_block_copy.RunRead(b_grid_desc, b_grid_buf);

        // advance the source windows to tile 1
        a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // A: LDS write of tile 0, then global read of tile 1
        a_block_copy.RunWrite(a_block_desc, a_block_buf);
        a_block_copy.RunRead(a_grid_desc, a_grid_buf);

        // B: LDS write of tile 0, then global read of tile 1
        b_block_copy.RunWrite(b_block_desc, b_block_buf);
        b_block_copy.RunRead(b_grid_desc, b_grid_buf);

        // steady state
        if constexpr(HasMainLoop)
        {
            const index_t num_main_iter = num_loop - 2;
            index_t iter                = 0;

            do
            {
                block_sync_lds();

                // (GEMM side computes tile `iter` here)

                block_sync_lds();

                // advance the source windows to tile iter + 2
                a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                // A: LDS write of tile iter + 1, then global read of tile iter + 2
                a_block_copy.RunWrite(a_block_desc, a_block_buf);
                a_block_copy.RunRead(a_grid_desc, a_grid_buf);

                // B: LDS write of tile iter + 1, then global read of tile iter + 2
                b_block_copy.RunWrite(b_block_desc, b_block_buf);
                b_block_copy.RunRead(b_grid_desc, b_grid_buf);
            } while(++iter < num_main_iter);
        }

        // epilogue
        block_sync_lds();

        // (GEMM side computes tile num_loop - 2 here)

        block_sync_lds();

        // LDS write of the last tile (num_loop - 1)
        a_block_copy.RunWrite(a_block_desc, a_block_buf);
        b_block_copy.RunWrite(b_block_desc, b_block_buf);

        block_sync_lds();

        // (GEMM side computes tile num_loop - 1 here)
    }

    // GEMM side of the pipeline: clears the accumulator and then runs
    // block_gemm once per iteration on the block buffers, separated by the
    // same barrier sequence as RunABBlockTransferPipeline.
    //
    //   HasMainLoop  - compile-time switch for the steady-state loop
    //   block_gemm   - object whose Run(a, b, c) accumulates into c_thread_buf
    //   num_loop     - total number of iterations
    template <bool HasMainLoop,
              typename ABlockBuffer,
              typename BBlockBuffer,
              typename BlockwiseGemm,
              typename CThreadBuffer>
    static __device__ void RunBlockGemmPipeline(ABlockBuffer& a_block_buf,
                                                BBlockBuffer& b_block_buf,
                                                const BlockwiseGemm& block_gemm,
                                                CThreadBuffer& c_thread_buf,
                                                index_t num_loop)
    {
        // start from a cleared accumulator
        c_thread_buf.Clear();

        // steady state
        if constexpr(HasMainLoop)
        {
            const index_t num_main_iter = num_loop - 2;
            index_t iter                = 0;

            do
            {
                block_sync_lds();

                // GEMM on tile `iter`
                block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

                block_sync_lds();

                // (transfer side: move window, LDS write iter + 1, global read iter + 2)
            } while(++iter < num_main_iter);
        }

        // epilogue
        block_sync_lds();

        // GEMM on tile num_loop - 2
        block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

        block_sync_lds();

        // (transfer side: LDS write of tile num_loop - 1)

        block_sync_lds();

        // GEMM on tile num_loop - 1
        block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
    }

    // Entry point: dispatches the calling thread to the transfer pipeline or
    // to the GEMM pipeline according to the thread group it belongs to.
    // The transfer group is tested first; a thread belonging to neither group
    // does nothing.
    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename BlockwiseGemm,
              typename CThreadBuffer>
    static __device__ void Run(const AGridDesc& a_grid_desc,
                               const ABlockDesc& a_block_desc,
                               ABlockTransfer& a_block_copy,
                               const AGridBuffer& a_grid_buf,
                               ABlockBuffer& a_block_buf,
                               const ABlockTransferStep& a_block_copy_step,
                               const BGridDesc& b_grid_desc,
                               const BBlockDesc& b_block_desc,
                               BBlockTransfer& b_block_copy,
                               const BGridBuffer& b_grid_buf,
                               BBlockBuffer& b_block_buf,
                               const BBlockTransferStep& b_block_copy_step,
                               const BlockwiseGemm& block_gemm,
                               CThreadBuffer& c_thread_buf,
                               index_t num_loop)
    {
        if(ABBlockTransferThreadGroup::IsBelong())
        {
            RunABBlockTransferPipeline<HasMainLoop>(a_grid_desc,
                                                    a_block_desc,
                                                    a_block_copy,
                                                    a_grid_buf,
                                                    a_block_buf,
                                                    a_block_copy_step,
                                                    b_grid_desc,
                                                    b_block_desc,
                                                    b_block_copy,
                                                    b_grid_buf,
                                                    b_block_buf,
                                                    b_block_copy_step,
                                                    num_loop);
            return;
        }

        if(BlockGemmThreadGroup::IsBelong())
        {
            RunBlockGemmPipeline<HasMainLoop>(
                a_block_buf, b_block_buf, block_gemm, c_thread_buf, num_loop);
        }
    }
};
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
View file @
5b3bd032
...
@@ -51,7 +51,6 @@ struct GridwiseGemmPipeline_v1<1>
...
@@ -51,7 +51,6 @@ struct GridwiseGemmPipeline_v1<1>
CThreadBuffer
&
c_thread_buf
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
index_t
num_loop
)
{
{
#if 1
// preload data into LDS
// preload data into LDS
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
...
@@ -98,80 +97,6 @@ struct GridwiseGemmPipeline_v1<1>
...
@@ -98,80 +97,6 @@ struct GridwiseGemmPipeline_v1<1>
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
}
}
#elif 1
// global read 0
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
// move to 1
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Initialize C
c_thread_buf
.
Clear
();
// LDS write 0
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
// global Read 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
// LDS write 0
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
// global Read 1
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
// main body
// FIXME: HasMainLoop = (num_loop) > 2
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
do
{
block_sync_lds
();
// GEMM i
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
// move to i + 2
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// LDS write i + 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
// global read i + 2
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
// LDS write i + 1
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
// global read i + 2
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
++
i
;
}
while
(
i
<
(
num_loop
-
2
));
}
// tail
{
block_sync_lds
();
// GEMM num_loop - 2
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
// LDS write num_loop - 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
block_sync_lds
();
// GEMM num_loop - 1
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
}
#endif
}
}
};
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
View file @
5b3bd032
#pragma once
#pragma once
#include "common_header.hpp"
#include "common_header.hpp"
namespace
ck
{
namespace
ck
{
template
<
typename
ABBlockTransferThreadGroup
,
typename
BlockGemmThreadGroup
>
struct
GridwiseGemmPipeline_v2
struct
GridwiseGemmPipeline_v2
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
__device__
constexpr
GridwiseGemmPipeline_v2
()
{
// TODO static assert
}
__host__
__device__
static
constexpr
bool
IsSupported
(
index_t
num_loop
)
__host__
__device__
static
constexpr
bool
IsSupported
(
index_t
num_loop
)
{
{
// TODO: improve applicability
// TODO: improve applicability
...
@@ -23,7 +13,7 @@ struct GridwiseGemmPipeline_v2
...
@@ -23,7 +13,7 @@ struct GridwiseGemmPipeline_v2
__host__
__device__
static
constexpr
bool
CalculateHasMainLoop
(
index_t
num_loop
)
__host__
__device__
static
constexpr
bool
CalculateHasMainLoop
(
index_t
num_loop
)
{
{
return
num_loop
/
2
>
1
;
return
(
num_loop
/
2
)
>
1
;
}
}
template
<
bool
HasMainLoop
,
template
<
bool
HasMainLoop
,
...
@@ -39,41 +29,46 @@ struct GridwiseGemmPipeline_v2
...
@@ -39,41 +29,46 @@ struct GridwiseGemmPipeline_v2
typename
BGridBuffer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
BBlockTransferStep
,
typename
BlockwiseGemm
,
typename
CThreadBuffer
>
typename
CThreadBuffer
>
static
__device__
void
RunABBlockTransferPipeline
(
const
AGridDesc
&
a_grid_desc
,
__device__
static
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_block_copy
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_block_copy
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
const
BBlockTransferStep
&
b_block_copy_step
,
index_t
num_loop
)
const
BlockwiseGemm
&
blockwise_gemm
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
{
{
// global read 0
// global read 0
a_block_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
a_block
wise
_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_block_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
b_block
wise
_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
// move to 1
// move to 1
a_block_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_block_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Initialize C
c_thread_buf
.
Clear
();
// LDS write 0
// LDS write 0
a_block_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
a_block
wise
_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
// global Read 1
// global Read 1
a_block_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
a_block
wise
_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
// LDS write 0
// LDS write 0
b_block_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
b_block
wise
_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
// global Read 1
// global Read 1
b_block_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
b_block
wise
_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
// main body
// main body
// FIXME: HasMainLoop = (num_loop) > 2
if
constexpr
(
HasMainLoop
)
if
constexpr
(
HasMainLoop
)
{
{
index_t
i
=
0
;
index_t
i
=
0
;
...
@@ -83,22 +78,23 @@ struct GridwiseGemmPipeline_v2
...
@@ -83,22 +78,23 @@ struct GridwiseGemmPipeline_v2
block_sync_lds
();
block_sync_lds
();
// GEMM i
// GEMM i
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
block_sync_lds
();
// move to i + 2
// move to i + 2
a_block_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
a_block
wise
_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_block_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
b_block
wise
_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// LDS write i + 1
// LDS write i + 1
a_block_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
a_block
wise
_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
// global read i + 2
// global read i + 2
a_block_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
a_block
wise
_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
// LDS write i + 1
// LDS write i + 1
b_block_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
b_block
wise
_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
// global read i + 2
// global read i + 2
b_block_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
b_block
wise
_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
++
i
;
++
i
;
}
while
(
i
<
(
num_loop
-
2
));
}
while
(
i
<
(
num_loop
-
2
));
...
@@ -109,128 +105,18 @@ struct GridwiseGemmPipeline_v2
...
@@ -109,128 +105,18 @@ struct GridwiseGemmPipeline_v2
block_sync_lds
();
block_sync_lds
();
// GEMM num_loop - 2
// GEMM num_loop - 2
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
block_sync_lds
();
// LDS write num_loop - 1
// LDS write num_loop - 1
a_block_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
a_block
wise
_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_block_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
b_block
wise
_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
block_sync_lds
();
block_sync_lds
();
// GEMM num_loop - 1
// GEMM num_loop - 1
}
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
}
template
<
bool
HasMainLoop
,
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
BlockwiseGemm
,
typename
CThreadBuffer
>
static
__device__
void
RunBlockGemmPipeline
(
ABlockBuffer
&
a_block_buf
,
BBlockBuffer
&
b_block_buf
,
const
BlockwiseGemm
&
block_gemm
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
{
// Initialize C
c_thread_buf
.
Clear
();
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
do
{
block_sync_lds
();
// GEMM i
block_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
// move to i + 2
// LDS write i + 1
// global read i + 2
// LDS write i + 1
// global read i + 2
++
i
;
}
while
(
i
<
(
num_loop
-
2
));
}
// tail
{
block_sync_lds
();
// GEMM num_loop - 2
block_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
// LDS write num_loop - 1
block_sync_lds
();
// GEMM num_loop - 1
block_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
}
}
template
<
bool
HasMainLoop
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
BlockwiseGemm
,
typename
CThreadBuffer
>
static
__device__
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_block_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_block_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
const
BlockwiseGemm
&
block_gemm
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
{
if
(
ABBlockTransferThreadGroup
::
IsBelong
())
{
RunABBlockTransferPipeline
<
HasMainLoop
>
(
a_grid_desc
,
a_block_desc
,
a_block_copy
,
a_grid_buf
,
a_block_buf
,
a_block_copy_step
,
b_grid_desc
,
b_block_desc
,
b_block_copy
,
b_grid_buf
,
b_block_buf
,
b_block_copy_step
,
num_loop
);
}
else
if
(
BlockGemmThreadGroup
::
IsBelong
())
{
RunBlockGemmPipeline
<
HasMainLoop
>
(
a_block_buf
,
b_block_buf
,
block_gemm
,
c_thread_buf
,
num_loop
);
}
}
}
}
};
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
View file @
5b3bd032
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
#include "gridwise_gemm_pipeline_v2.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -127,7 +128,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -127,7 +128,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
#if 1
using
GridwiseGemmPipe
=
GridwiseGemmPipeline_v1
<
NumGemmKPrefetchStage
>
;
using
GridwiseGemmPipe
=
GridwiseGemmPipeline_v1
<
NumGemmKPrefetchStage
>
;
#else
using
GridwiseGemmPipe
=
GridwiseGemmPipeline_v2
;
#endif
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()
{
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle
_v2
.hpp
→
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_
producer_consumer_
cshuffle.hpp
View file @
5b3bd032
...
@@ -7,8 +7,7 @@
...
@@ -7,8 +7,7 @@
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
#include "gridwise_gemm_pipeline_producer_consumer.hpp"
#include "gridwise_gemm_pipeline_v2.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -27,18 +26,20 @@ __global__ void
...
@@ -27,18 +26,20 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_gemm_xdl_cshuffle_v2
(
const
FloatAB
*
__restrict__
p_a_grid
,
kernel_gemm_xdl_producer_consumer_cshuffle
(
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_a_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
AElementwiseOperation
a_element_op
,
FloatC
*
__restrict__
p_c_grid
,
const
BElementwiseOperation
b_element_op
,
const
AElementwiseOperation
a_element_op
,
const
CElementwiseOperation
c_element_op
,
const
BElementwiseOperation
b_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
CElementwiseOperation
c_element_op
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
const
Block2CTileMap
block_2_ctile_map
)
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2CTileMap
block_2_ctile_map
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
...
@@ -52,6 +53,18 @@ __global__ void
...
@@ -52,6 +53,18 @@ __global__ void
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
block_2_ctile_map
);
block_2_ctile_map
);
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_c_grid
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
block_2_ctile_map
;
#endif // end of #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
}
}
template
<
typename
FloatAB
,
template
<
typename
FloatAB
,
...
@@ -97,7 +110,7 @@ template <typename FloatAB,
...
@@ -97,7 +110,7 @@ template <typename FloatAB,
index_t
CShuffleNXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
>
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
>
struct
GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle
_v2
struct
GridwiseGemm_k0mk1_k0nk1_mn_xdl_
producer_consumer_
cshuffle
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -114,10 +127,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
...
@@ -114,10 +127,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
static
constexpr
auto
AK1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
AK1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1
=
Number
<
BK1Value
>
{};
static
constexpr
auto
BK1
=
Number
<
BK1Value
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
ABBlockTransferThreadGroupSize
+
BlockGemmThreadGroupSize
>
;
#if 0
struct
ABBlockTransferThreadGroup
struct
ABBlockTransferThreadGroup
{
{
__device__
static
constexpr
index_t
GetNumOfThread
()
__device__
static
constexpr
index_t
GetNumOfThread
()
...
@@ -151,22 +160,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
...
@@ -151,22 +160,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
}
}
};
};
using CShuffleBlockTransferThreadGroup = ThisThreadBlock;
using
CShuffleBlockTransferThreadGroup
=
#else
ThisThreadBlock
<
ABBlockTransferThreadGroupSize
+
BlockGemmThreadGroupSize
>
;
using
ABBlockTransferThreadGroup
=
ThisThreadBlock
;
using
BlockGemmThreadGroup
=
ThisThreadBlock
;
using
CShuffleBlockTransferThreadGroup
=
ThisThreadBlock
;
#endif
#if 1
using
GridwiseGemmPipe
=
GridwiseGemmPipelineProducerConsumer
<
ABBlockTransferThreadGroup
,
// gridwise GEMM pipeline
BlockGemmThreadGroup
,
using
GridwiseGemmPipe
=
GridwiseGemmPipeline_v1
<
NumGemmKPrefetchStage
>
;
NumGemmKPrefetchStage
>
;
#else
// gridwise GEMM pipeline
using
GridwiseGemmPipe
=
GridwiseGemmPipeline_v2
<
ABBlockTransferThreadGroup
,
BlockGemmThreadGroup
,
NumGemmKPrefetchStage
>
;
#endif
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()
{
{
...
...
library/src/utility/conv_fwd_util.cpp
View file @
5b3bd032
...
@@ -37,16 +37,16 @@ std::size_t get_flops(ck::index_t N,
...
@@ -37,16 +37,16 @@ std::size_t get_flops(ck::index_t N,
}
}
ConvParams
::
ConvParams
()
ConvParams
::
ConvParams
()
:
num_dim_spatial
(
2
),
:
num_dim_spatial
(
2
),
N
(
128
),
N
(
128
),
K
(
256
),
K
(
256
),
C
(
192
),
C
(
192
),
filter_spatial_lengths
(
2
,
3
),
filter_spatial_lengths
(
2
,
3
),
input_spatial_lengths
(
2
,
71
),
input_spatial_lengths
(
2
,
71
),
conv_filter_strides
(
2
,
2
),
conv_filter_strides
(
2
,
2
),
conv_filter_dilations
(
2
,
1
),
conv_filter_dilations
(
2
,
1
),
input_left_pads
(
2
,
1
),
input_left_pads
(
2
,
1
),
input_right_pads
(
2
,
1
)
input_right_pads
(
2
,
1
)
{
{
}
}
...
@@ -77,9 +77,9 @@ ConvParams::ConvParams(ck::index_t n_dim,
...
@@ -77,9 +77,9 @@ ConvParams::ConvParams(ck::index_t n_dim,
conv_filter_dilations
.
size
()
!=
num_dim_spatial
||
conv_filter_dilations
.
size
()
!=
num_dim_spatial
||
input_left_pads
.
size
()
!=
num_dim_spatial
||
input_right_pads
.
size
()
!=
num_dim_spatial
)
input_left_pads
.
size
()
!=
num_dim_spatial
||
input_right_pads
.
size
()
!=
num_dim_spatial
)
{
{
throw
(
std
::
runtime_error
(
throw
(
"ConvParams::GetOutputSpatialLengths: "
std
::
runtime_error
(
"ConvParams::GetOutputSpatialLengths: "
"parameter size is different from number of declared dimensions!"
));
"parameter size is different from number of declared dimensions!"
));
}
}
}
}
...
@@ -91,9 +91,9 @@ std::vector<ck::index_t> ConvParams::GetOutputSpatialLengths() const
...
@@ -91,9 +91,9 @@ std::vector<ck::index_t> ConvParams::GetOutputSpatialLengths() const
conv_filter_dilations
.
size
()
!=
num_dim_spatial
||
conv_filter_dilations
.
size
()
!=
num_dim_spatial
||
input_left_pads
.
size
()
!=
num_dim_spatial
||
input_right_pads
.
size
()
!=
num_dim_spatial
)
input_left_pads
.
size
()
!=
num_dim_spatial
||
input_right_pads
.
size
()
!=
num_dim_spatial
)
{
{
throw
(
std
::
runtime_error
(
throw
(
"ConvParams::GetOutputSpatialLengths: "
std
::
runtime_error
(
"ConvParams::GetOutputSpatialLengths: "
"parameter size is different from number of declared dimensions!"
));
"parameter size is different from number of declared dimensions!"
));
}
}
std
::
vector
<
ck
::
index_t
>
out_spatial_len
(
num_dim_spatial
,
0
);
std
::
vector
<
ck
::
index_t
>
out_spatial_len
(
num_dim_spatial
,
0
);
...
@@ -101,8 +101,7 @@ std::vector<ck::index_t> ConvParams::GetOutputSpatialLengths() const
...
@@ -101,8 +101,7 @@ std::vector<ck::index_t> ConvParams::GetOutputSpatialLengths() const
{
{
// XEff = (X - 1) * conv_dilation_w + 1;
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const
ck
::
index_t
idx_eff
=
const
ck
::
index_t
idx_eff
=
(
filter_spatial_lengths
[
i
]
-
1
)
*
conv_filter_dilations
[
i
]
+
1
;
(
filter_spatial_lengths
[
i
]
-
1
)
*
conv_filter_dilations
[
i
]
+
1
;
out_spatial_len
[
i
]
=
out_spatial_len
[
i
]
=
(
input_spatial_lengths
[
i
]
+
input_left_pads
[
i
]
+
input_right_pads
[
i
]
-
idx_eff
)
/
(
input_spatial_lengths
[
i
]
+
input_left_pads
[
i
]
+
input_right_pads
[
i
]
-
idx_eff
)
/
conv_filter_strides
[
i
]
+
conv_filter_strides
[
i
]
+
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment