gaoqiong / composable_kernel / Commits / 4a96c2e4
"...composable_kernel.git" did not exist on "66edb2590d47a4dd4208e10998b19d0318b1cd71"
Commit 4a96c2e4, authored Apr 21, 2022 by Chao Liu

refactor

parent 70e8cc76
Showing 11 changed files with 402 additions and 587 deletions (+402 -587)
example/01_gemm/gemm_xdl_fp16.cpp                                        +8   -8
include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp              +4   -3
include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp     +4   -3
include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v2.hpp  +4   -3
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp      +82  -0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp      +56  -128
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp      +62  -71
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp  +88  -118
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp  +46  -94
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp      +47  -158
include/ck/utility/thread_group.hpp                                     +1   -1
example/01_gemm/gemm_xdl_fp16.cpp  (View file @ 4a96c2e4)

...
@@ -11,9 +11,8 @@
 #include "host_tensor.hpp"
 #include "host_tensor_generator.hpp"
 #include "device_tensor.hpp"
-//#include "device_gemm_xdl.hpp"
-//#include "device_gemm_xdl_c_shuffle.hpp"
-//#include "device_gemm_xdl_cshuffle.hpp"
+#include "device_gemm_xdl.hpp"
+#include "device_gemm_xdl_cshuffle.hpp"
 #include "device_gemm_xdl_cshuffle_v2.hpp"
 #include "element_wise_operation.hpp"
 #include "reference_gemm.hpp"
...
@@ -53,12 +52,12 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
 //######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 //######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
 //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-// < Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
+   < Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
 // // 1-stage prefetch
-   < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
+// < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
 // // 2-stage prefetch
 // < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
-#elif 1
+#elif 0
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_v2
 //######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
 //######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
...
@@ -67,7 +66,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
 // all thread
    < Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 0, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
 // < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 0, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
-#elif 1
+#elif 0
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_v2
 //######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
 //######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
...
@@ -82,7 +81,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
 //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
 //######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
 //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-   < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
+   < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
+// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
 // < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>;
 // < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>;
 #endif
...
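The comment tables above are CK's convention for documenting tuning parameters: each `//######|` column header names the template argument directly beneath it, and every instance line is one fully spelled-out specialization. A reduced, compilable illustration of the idiom (the `GemmConfig` type here is hypothetical, not a CK type; the real DeviceGemm templates take roughly forty parameters):

#include <type_traits>

// Hypothetical reduced config type standing in for a DeviceGemm instance.
template <int MPerBlock, int NPerBlock, int KPerBlock>
struct GemmConfig
{
    static constexpr int m = MPerBlock;
    static constexpr int n = NPerBlock;
    static constexpr int k = KPerBlock;
};

//                     MPer   NPer   KPer
//                     Block  Block  Block
using Cfg = GemmConfig<256,   128,   32>;

static_assert(Cfg::m == 256 && Cfg::n == 128 && Cfg::k == 32,
              "each column in the comment header names the argument directly below it");

int main() { return 0; }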
include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp  (View file @ 4a96c2e4)

...
@@ -317,13 +317,14 @@ struct DeviceGemmXdl
            const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

-           const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
+           const auto K =
+               arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

-           const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+           const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K);

            float ave_time = 0;

-           if(has_main_k0_block_loop)
+           if(has_main_k_block_loop)
            {
                const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp  (View file @ 4a96c2e4)

...
@@ -462,13 +462,14 @@ struct DeviceGemm_Xdl_CShuffle
            const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

-           const auto K0 = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0);
+           const auto K =
+               arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

-           const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+           const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K);

            float ave_time = 0;

-           if(has_main_k0_block_loop)
+           if(has_main_k_block_loop)
            {
                const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v2.hpp  (View file @ 4a96c2e4)

...
@@ -464,13 +464,14 @@ struct DeviceGemm_Xdl_CShuffle_v2
            const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

-           const auto K0 = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0);
+           const auto K =
+               arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

-           const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+           const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K);

            float ave_time = 0;

-           if(has_main_k0_block_loop)
+           if(has_main_k_block_loop)
            {
                const auto kernel = kernel_gemm_xdl_cshuffle_v2<GridwiseGemm,
...
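All three device ops above get the same mechanical change: instead of deriving a `has_main_k0_block_loop` flag from `K0` locally, the launcher reconstructs the full reduction length K = K0 * K1 from the descriptor and asks the gridwise GEMM for `CalculateHasMainKBlockLoop(K)`, which (per the grid-layer diffs below) now just forwards to the pipeline class. A minimal host-side sketch of that predicate chain; the concrete tile sizes and the `main()` driver are illustrative, not from the commit:

#include <cassert>
#include <cstdint>

using index_t = int32_t;

// Mirrors GridwiseGemmPipeline_v1<1>::CalculateHasMainLoop from this commit:
// a main loop exists only if more than one K-block iteration remains.
constexpr bool calculate_has_main_loop(index_t num_loop) { return num_loop > 1; }

int main()
{
    constexpr index_t KPerBlock = 32; // tile depth along K, as in the example file above
    constexpr index_t K0 = 8, K1 = 8; // descriptor lengths at I0 and I2

    const index_t K        = K0 * K1;       // what the device op now computes
    const index_t num_loop = K / KPerBlock; // loop count handed to the pipeline

    assert(calculate_has_main_loop(num_loop)); // 64 / 32 = 2 > 1
    return 0;
}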
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp  (View file @ 4a96c2e4)

...
@@ -17,6 +17,88 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

+template <typename GridwiseGemm,
+          typename FloatAB,
+          typename FloatC,
+          typename GemmDesc,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation,
+          bool HasMainKBlockLoop,
+          index_t MaxGroupCount>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+        kernel_grouped_gemm_xdlops_v2r3(
+            const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_desc_,
+            const index_t group_count,
+            const AElementwiseOperation a_element_op,
+            const BElementwiseOperation b_element_op,
+            const CElementwiseOperation c_element_op)
+{
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+
+    const index_t block_id = get_block_1d_id();
+
+#if 1
+    static_for<0, MaxGroupCount, 1>{}([&](auto i) {
+        if(block_id >= gemm_desc_[i].BlockStart_ && block_id < gemm_desc_[i].BlockEnd_ &&
+           i < group_count)
+        {
+            auto group_id = i;
+
+            GridwiseGemm::template Run<HasMainKBlockLoop>(
+                gemm_desc_[group_id].a_ptr,
+                gemm_desc_[group_id].b_ptr,
+                gemm_desc_[group_id].c_ptr,
+                p_shared,
+                gemm_desc_[group_id].a_grid_desc_k0_m_k1_,
+                gemm_desc_[group_id].b_grid_desc_k0_n_k1_,
+                gemm_desc_[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+                a_element_op,
+                b_element_op,
+                c_element_op,
+                gemm_desc_[group_id].grouped_gemm_block_2_ctile_map_);
+        }
+    });
+#else
+    const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_desc_);
+
+    index_t group_id = 0;
+
+    static_for<0, MaxGroupCount, 1>{}([&](auto i) {
+        group_id = (block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd &&
+                    i < group_count)
+                       ? i
+                       : group_id;
+    });
+
+    const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart;
+
+    GridwiseGemm::template Run<HasMainKBlockLoop>(
+        gemm_desc_ptr[group_id].a_ptr,
+        gemm_desc_ptr[group_id].b_ptr,
+        gemm_desc_ptr[group_id].c_ptr,
+        p_shared,
+        gemm_desc_ptr[group_id].a_grid_desc_k0_m_k1_,
+        gemm_desc_ptr[group_id].b_grid_desc_k0_n_k1_,
+        gemm_desc_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+        a_element_op,
+        b_element_op,
+        c_element_op,
+        gemm_desc_ptr[group_id].block_2_ctile_map_,
+        block_id_grp);
+#endif
+#else
+    ignore = gemm_desc_;
+    ignore = group_count;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+}
+
 template <typename ADataType,
           typename BDataType,
           typename CDataType,
...
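The kernel added above maps the flat workgroup index onto a GEMM group by scanning the half-open [BlockStart_, BlockEnd_) ranges carried by each descriptor; the `static_for` makes that scan fully unrolled at compile time. A plain host-side sketch of the same lookup (the `GroupRange` struct and `find_group` helper are stand-ins for the real `GemmDesc` and device code, introduced here only for illustration):

#include <cassert>
#include <vector>

// Stand-in for the per-group fields the kernel reads; the real GemmDesc also
// carries pointers, tensor descriptors, and the block-to-C-tile map.
struct GroupRange
{
    int BlockStart_; // first workgroup assigned to this group
    int BlockEnd_;   // one past the last workgroup of this group
};

// Mirrors the static_for scan in kernel_grouped_gemm_xdlops_v2r3: a block
// belongs to group i iff BlockStart_ <= block_id < BlockEnd_.
int find_group(const std::vector<GroupRange>& groups, int block_id)
{
    for(int i = 0; i < static_cast<int>(groups.size()); ++i)
        if(block_id >= groups[i].BlockStart_ && block_id < groups[i].BlockEnd_)
            return i;
    return -1; // block not assigned to any group
}

int main()
{
    // three GEMMs occupying 4, 2, and 6 workgroups respectively
    std::vector<GroupRange> groups{{0, 4}, {4, 6}, {6, 12}};
    assert(find_group(groups, 0) == 0);
    assert(find_group(groups, 5) == 1);
    assert(find_group(groups, 11) == 2);
    return 0;
}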
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp  (View file @ 4a96c2e4)

-#ifndef CK_GRIDWISE_GEMM_PIPELINE_V1_HPP
-#define CK_GRIDWISE_GEMM_PIPELINE_V1_HPP
+#pragma once

 #include "common_header.hpp"

 namespace ck {

-template <typename AGridDesc,
-          typename ABlockDesc,
-          typename ABlockTransfer,
-          typename AGridBuffer,
-          typename ABlockBuffer,
-          typename ABlockTransferStep,
-          typename BGridDesc,
-          typename BBlockDesc,
-          typename BBlockTransfer,
-          typename BGridBuffer,
-          typename BBlockBuffer,
-          typename BBlockTransferStep,
-          typename BlockwiseGemm,
-          typename CThreadBuffer,
-          index_t NumPrefetch,
-          bool HasMainLoop>
+template <index_t NumPrefetch>
 struct GridwiseGemmPipeline_v1;

 // 1-stage prefetch
-template <typename AGridDesc,
-          typename ABlockDesc,
-          typename ABlockTransfer,
-          typename AGridBuffer,
-          typename ABlockBuffer,
-          typename ABlockTransferStep,
-          typename BGridDesc,
-          typename BBlockDesc,
-          typename BBlockTransfer,
-          typename BGridBuffer,
-          typename BBlockBuffer,
-          typename BBlockTransferStep,
-          typename BlockwiseGemm,
-          typename CThreadBuffer,
-          bool HasMainLoop>
-struct GridwiseGemmPipeline_v1<AGridDesc,
-                               ABlockDesc,
-                               ABlockTransfer,
-                               AGridBuffer,
-                               ABlockBuffer,
-                               ABlockTransferStep,
-                               BGridDesc,
-                               BBlockDesc,
-                               BBlockTransfer,
-                               BGridBuffer,
-                               BBlockBuffer,
-                               BBlockTransferStep,
-                               BlockwiseGemm,
-                               CThreadBuffer,
-                               1,
-                               HasMainLoop>
+template <>
+struct GridwiseGemmPipeline_v1<1>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};

-    static __device__ void Run(const AGridDesc& a_grid_desc,
+    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */)
+    {
+        return true;
+    }
+
+    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
+    {
+        return num_loop > 1;
+    }
+
+    template <bool HasMainLoop,
+              typename AGridDesc,
+              typename ABlockDesc,
+              typename ABlockTransfer,
+              typename AGridBuffer,
+              typename ABlockBuffer,
+              typename ABlockTransferStep,
+              typename BGridDesc,
+              typename BBlockDesc,
+              typename BBlockTransfer,
+              typename BGridBuffer,
+              typename BBlockBuffer,
+              typename BBlockTransferStep,
+              typename BlockwiseGemm,
+              typename CThreadBuffer>
+    __device__ static void Run(const AGridDesc& a_grid_desc,
                                const ABlockDesc& a_block_desc,
                                ABlockTransfer& a_blockwise_copy,
                                const AGridBuffer& a_grid_buf,
...
@@ -75,51 +51,7 @@ struct GridwiseGemmPipeline_v1<AGridDesc,
                                CThreadBuffer& c_thread_buf,
                                index_t num_loop)
     {
-#if 0
+#if 1
-        // preload data into LDS
-        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
-        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
-
-        // Initialize C
-        c_thread_buf.Clear();
-
-        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
-        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
-
-        // main body
-        if constexpr(HasMainLoop)
-        {
-            index_t i = 0;
-
-            do
-            {
-                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
-                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
-
-                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
-
-                block_sync_lds();
-
-                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
-
-                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
-
-                block_sync_lds();
-
-                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
-                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
-
-                ++i;
-            } while(i < (num_loop - 1));
-        }
-
-        // tail
-        {
-            block_sync_lds();
-
-            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
-        }
-#elif 0
         // preload data into LDS
         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
...
@@ -244,41 +176,38 @@ struct GridwiseGemmPipeline_v1<AGridDesc,
 };

 // 2-stage prefetch
-template <typename AGridDesc,
-          typename ABlockDesc,
-          typename ABlockTransfer,
-          typename AGridBuffer,
-          typename ABlockBuffer,
-          typename ABlockTransferStep,
-          typename BGridDesc,
-          typename BBlockDesc,
-          typename BBlockTransfer,
-          typename BGridBuffer,
-          typename BBlockBuffer,
-          typename BBlockTransferStep,
-          typename BlockwiseGemm,
-          typename CThreadBuffer,
-          bool HasMainLoop>
-struct GridwiseGemmPipeline_v1<AGridDesc,
-                               ABlockDesc,
-                               ABlockTransfer,
-                               AGridBuffer,
-                               ABlockBuffer,
-                               ABlockTransferStep,
-                               BGridDesc,
-                               BBlockDesc,
-                               BBlockTransfer,
-                               BGridBuffer,
-                               BBlockBuffer,
-                               BBlockTransferStep,
-                               BlockwiseGemm,
-                               CThreadBuffer,
-                               2,
-                               HasMainLoop>
+template <>
+struct GridwiseGemmPipeline_v1<2>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};

+    __host__ __device__ static constexpr bool IsSupported(index_t num_loop)
+    {
+        // TODO: improve applicability
+        return num_loop % 2 == 0;
+    }
+
+    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
+    {
+        return (num_loop / 2) > 1;
+    }
+
+    template <bool HasMainLoop,
+              typename AGridDesc,
+              typename ABlockDesc,
+              typename ABlockTransfer,
+              typename AGridBuffer,
+              typename ABlockBuffer,
+              typename ABlockTransferStep,
+              typename BGridDesc,
+              typename BBlockDesc,
+              typename BBlockTransfer,
+              typename BGridBuffer,
+              typename BBlockBuffer,
+              typename BBlockTransferStep,
+              typename BlockwiseGemm,
+              typename CThreadBuffer>
     static __device__ void Run(const AGridDesc& a_grid_desc,
                                const ABlockDesc& a_block_desc,
                                ABlockTransfer& a_blockwise_copy,
...
@@ -395,4 +324,3 @@ struct GridwiseGemmPipeline_v1<AGridDesc,
 };

 } // namespace ck
-#endif
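After this rewrite the pipeline's identity is just its prefetch depth; loop-count legality and the main-loop decision become static queries, and `Run` deduces all tensor and transfer types per call. The predicate logic can be exercised host-side; this sketch mirrors the two specializations above with the `__host__ __device__` qualifiers dropped so it compiles as plain C++:

#include <cassert>
#include <cstdint>

using index_t = int32_t;

// Shape of the refactored interface: the struct is parameterized only by the
// prefetch depth, everything else moves onto the (templated) Run method.
template <index_t NumPrefetch>
struct GridwiseGemmPipeline_v1;

template <>
struct GridwiseGemmPipeline_v1<1>
{
    // 1-stage prefetch places no restriction on the loop count.
    static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
    static constexpr bool CalculateHasMainLoop(index_t num_loop) { return num_loop > 1; }
};

template <>
struct GridwiseGemmPipeline_v1<2>
{
    // 2-stage prefetch currently requires an even number of K loops.
    static constexpr bool IsSupported(index_t num_loop) { return num_loop % 2 == 0; }
    static constexpr bool CalculateHasMainLoop(index_t num_loop) { return (num_loop / 2) > 1; }
};

int main()
{
    assert(GridwiseGemmPipeline_v1<1>::IsSupported(3));          // any loop count is fine
    assert(!GridwiseGemmPipeline_v1<2>::IsSupported(3));         // odd count rejected
    assert(GridwiseGemmPipeline_v1<2>::CalculateHasMainLoop(6)); // 6 / 2 = 3 > 1
    return 0;
}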
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp  (View file @ 4a96c2e4)

...
@@ -4,62 +4,8 @@

 namespace ck {

 template <typename ABBlockTransferThreadGroup,
-          typename BlockGemmThreadGroup,
-          typename AGridDesc,
-          typename ABlockDesc,
-          typename ABlockTransfer,
-          typename AGridBuffer,
-          typename ABlockBuffer,
-          typename ABlockTransferStep,
-          typename BGridDesc,
-          typename BBlockDesc,
-          typename BBlockTransfer,
-          typename BGridBuffer,
-          typename BBlockBuffer,
-          typename BBlockTransferStep,
-          typename BlockwiseGemm,
-          typename CThreadBuffer,
-          index_t NumPrefetch,
-          bool HasMainLoop>
-struct GridwiseGemmPipeline_v2;
-
-// 1-stage prefetch
-template <typename ABBlockTransferThreadGroup,
-          typename BlockGemmThreadGroup,
-          typename AGridDesc,
-          typename ABlockDesc,
-          typename ABlockTransfer,
-          typename AGridBuffer,
-          typename ABlockBuffer,
-          typename ABlockTransferStep,
-          typename BGridDesc,
-          typename BBlockDesc,
-          typename BBlockTransfer,
-          typename BGridBuffer,
-          typename BBlockBuffer,
-          typename BBlockTransferStep,
-          typename BlockwiseGemm,
-          typename CThreadBuffer,
-          bool HasMainLoop>
-struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
-                               BlockGemmThreadGroup,
-                               AGridDesc,
-                               ABlockDesc,
-                               ABlockTransfer,
-                               AGridBuffer,
-                               ABlockBuffer,
-                               ABlockTransferStep,
-                               BGridDesc,
-                               BBlockDesc,
-                               BBlockTransfer,
-                               BGridBuffer,
-                               BBlockBuffer,
-                               BBlockTransferStep,
-                               BlockwiseGemm,
-                               CThreadBuffer,
-                               1,
-                               HasMainLoop>
+          typename BlockGemmThreadGroup>
+struct GridwiseGemmPipeline_v2
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
...
@@ -69,6 +15,31 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
         // TODO static assert
     }

+    __host__ __device__ static constexpr bool IsSupported(index_t num_loop)
+    {
+        // TODO: improve applicability
+        return num_loop % 2 == 0;
+    }
+
+    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
+    {
+        return num_loop / 2 > 1;
+    }
+
+    template <bool HasMainLoop,
+              typename AGridDesc,
+              typename ABlockDesc,
+              typename ABlockTransfer,
+              typename AGridBuffer,
+              typename ABlockBuffer,
+              typename ABlockTransferStep,
+              typename BGridDesc,
+              typename BBlockDesc,
+              typename BBlockTransfer,
+              typename BGridBuffer,
+              typename BBlockBuffer,
+              typename BBlockTransferStep,
+              typename CThreadBuffer>
     static __device__ void RunABBlockTransferPipeline(const AGridDesc& a_grid_desc,
                                                       const ABlockDesc& a_block_desc,
                                                       ABlockTransfer& a_block_copy,
...
@@ -151,6 +122,11 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
         }
     }

+    template <bool HasMainLoop,
+              typename ABlockBuffer,
+              typename BBlockBuffer,
+              typename BlockwiseGemm,
+              typename CThreadBuffer>
     static __device__ void RunBlockGemmPipeline(ABlockBuffer& a_block_buf,
                                                 BBlockBuffer& b_block_buf,
                                                 const BlockwiseGemm& block_gemm,
...
@@ -161,7 +137,6 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
         c_thread_buf.Clear();

         // main body
-        // FIXME: HasMainLoop = (num_loop) > 2
         if constexpr(HasMainLoop)
         {
             index_t i = 0;
...
@@ -205,6 +180,21 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
         }
     }

+    template <bool HasMainLoop,
+              typename AGridDesc,
+              typename ABlockDesc,
+              typename ABlockTransfer,
+              typename AGridBuffer,
+              typename ABlockBuffer,
+              typename ABlockTransferStep,
+              typename BGridDesc,
+              typename BBlockDesc,
+              typename BBlockTransfer,
+              typename BGridBuffer,
+              typename BBlockBuffer,
+              typename BBlockTransferStep,
+              typename BlockwiseGemm,
+              typename CThreadBuffer>
     static __device__ void Run(const AGridDesc& a_grid_desc,
                                const ABlockDesc& a_block_desc,
                                ABlockTransfer& a_block_copy,
...
@@ -223,23 +213,24 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
     {
         if(ABBlockTransferThreadGroup::IsBelong())
         {
-            RunABBlockTransferPipeline(a_grid_desc,
+            RunABBlockTransferPipeline<HasMainLoop>(a_grid_desc,
                                                     a_block_desc,
                                                     a_block_copy,
                                                     a_grid_buf,
                                                     a_block_buf,
                                                     a_block_copy_step,
                                                     b_grid_desc,
                                                     b_block_desc,
                                                     b_block_copy,
                                                     b_grid_buf,
                                                     b_block_buf,
                                                     b_block_copy_step,
                                                     num_loop);
         }
         else if(BlockGemmThreadGroup::IsBelong())
         {
-            RunBlockGemmPipeline(a_block_buf, b_block_buf, block_gemm, c_thread_buf, num_loop);
+            RunBlockGemmPipeline<HasMainLoop>(
+                a_block_buf, b_block_buf, block_gemm, c_thread_buf, num_loop);
         }
     }
 };
...
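Pipeline v2 splits the thread block into an A/B transfer group and a block-GEMM group, and `Run` dispatches each thread into exactly one of the two sub-pipelines via `IsBelong()`. A standalone sketch of that dispatch pattern; the two groups here are hypothetical stand-ins, and the real `IsBelong()` takes no argument because it reads the hardware thread id internally:

#include <cstdio>

// Hypothetical stand-ins: partition 384 "threads" into a transfer group
// [0, 256) and a block-GEMM group [256, 384), mimicking the
// ABBlockTransferThreadGroup / BlockGemmThreadGroup split.
struct TransferGroup
{
    static bool IsBelong(int tid) { return tid < 256; }
};
struct GemmGroup
{
    static bool IsBelong(int tid) { return tid >= 256 && tid < 384; }
};

// Mirrors the branch structure of GridwiseGemmPipeline_v2::Run: each thread
// takes exactly one role for the whole GEMM.
void run(int tid)
{
    if(TransferGroup::IsBelong(tid))
        std::printf("thread %d: global-to-LDS copy pipeline\n", tid);
    else if(GemmGroup::IsBelong(tid))
        std::printf("thread %d: blockwise GEMM pipeline\n", tid);
}

int main()
{
    run(0);   // transfer role
    run(300); // compute role
    return 0;
}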
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp  (View file @ 4a96c2e4)

...
@@ -4,8 +4,8 @@
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer_v4r1.hpp"
-#include "blockwise_tensor_slice_transfer_v6r1.hpp"
+#include "thread_group_tensor_slice_transfer_v4r1.hpp"
+#include "thread_group_tensor_slice_transfer_v6r1.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "gridwise_gemm_pipeline_v1.hpp"
...
@@ -21,7 +21,7 @@ template <typename GridwiseGemm,
           typename BGridDesc_BK0_N_BK1,
           typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename Block2CTileMap,
-          bool HasMainK0BlockLoop>
+          bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...
@@ -41,17 +41,17 @@ __global__ void
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

-    GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                   p_b_grid,
                                                   p_c_grid,
                                                   p_shared,
                                                   a_element_op,
                                                   b_element_op,
                                                   c_element_op,
                                                   a_grid_desc_ak0_m_ak1,
                                                   b_grid_desc_bk0_n_bk1,
                                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                   block_2_ctile_map);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
...
@@ -125,6 +125,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     static constexpr auto AK1 = Number<AK1Value>{};
     static constexpr auto BK1 = Number<BK1Value>{};

+    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+
+    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+
     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
         // A matrix in LDS memory, dst of blockwise copy
...
@@ -190,10 +194,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                   const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                   const CGridDesc_M_N& c_grid_desc_m_n)
     {
-        // static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
-        //                   is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
-        //               "wrong! K1 need to be known at compile-time");
-
         static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                           (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                       "Invalid tuning param!");
...
@@ -208,21 +208,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
             return false;

-        // check NumGemmKPrefetchStage
-        if constexpr(NumGemmKPrefetchStage == 1)
-        {
-            // 1-stage prefetch always supported
-        }
-        else if constexpr(NumGemmKPrefetchStage == 2)
-        {
-            // 2-stage prefetch currently only support even number of K0 loop
-            // TODO: add support for odd number of K0 loop
-            if(!((K / KPerBlock) % 2 == 0))
-            {
-                return false;
-            }
-        }
-        else
+        // check gridwise gemm pipeline
+        const auto num_k_loop = K / KPerBlock;
+
+        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
             return false;
         }
...
@@ -242,12 +231,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         return grid_size;
     }

-    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
+    // TODO move this function into GEMM-pipeline class
+    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
     {
-        const bool has_main_k0_block_loop =
-            ((K0 * AK1) / (NumGemmKPrefetchStage * KPerBlock)) > 1;
-
-        return has_main_k0_block_loop;
+        const index_t num_loop = K / KPerBlock;
+
+        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }

     __host__ __device__ static constexpr auto
...
@@ -315,7 +303,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;

-    template <bool HasMainK0BlockLoop, typename Block2CTileMap>
+    template <bool HasMainKBlockLoop, typename Block2CTileMap>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
...
@@ -358,28 +346,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // A matrix blockwise copy
         auto a_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                               AElementwiseOperation,
                                               ck::tensor_operation::element_wise::PassThrough,
                                               InMemoryDataOperationEnum::Set,
                                               Sequence<AK0, MPerBlock, AK1>,
                                               ABlockTransferThreadClusterLengths_AK0_M_AK1,
                                               ABlockTransferThreadClusterArrangeOrder,
                                               FloatAB,
                                               FloatAB,
                                               decltype(a_grid_desc_ak0_m_ak1),
                                               decltype(a_block_desc_ak0_m_ak1),
                                               ABlockTransferSrcAccessOrder,
                                               Sequence<1, 0, 2>,
                                               ABlockTransferSrcVectorDim,
                                               2,
                                               ABlockTransferSrcScalarPerVector,
                                               ABlockTransferDstScalarPerVector_AK1,
                                               1,
                                               1,
                                               AThreadTransferSrcResetCoordinateAfterRun,
                                               true,
                                               NumGemmKPrefetchStage>(
                 a_grid_desc_ak0_m_ak1,
                 make_multi_index(0, m_block_data_idx_on_grid, 0),
                 a_element_op,
...
@@ -389,28 +377,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // B matrix blockwise copy
         auto b_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                               BElementwiseOperation,
                                               ck::tensor_operation::element_wise::PassThrough,
                                               InMemoryDataOperationEnum::Set,
                                               Sequence<BK0, NPerBlock, BK1>,
                                               BBlockTransferThreadClusterLengths_BK0_N_BK1,
                                               BBlockTransferThreadClusterArrangeOrder,
                                               FloatAB,
                                               FloatAB,
                                               decltype(b_grid_desc_bk0_n_bk1),
                                               decltype(b_block_desc_bk0_n_bk1),
                                               BBlockTransferSrcAccessOrder,
                                               Sequence<1, 0, 2>,
                                               BBlockTransferSrcVectorDim,
                                               2,
                                               BBlockTransferSrcScalarPerVector,
                                               BBlockTransferDstScalarPerVector_BK1,
                                               1,
                                               1,
                                               BThreadTransferSrcResetCoordinateAfterRun,
                                               true,
                                               NumGemmKPrefetchStage>(
                 b_grid_desc_bk0_n_bk1,
                 make_multi_index(0, n_block_data_idx_on_grid, 0),
                 b_element_op,
...
@@ -429,7 +417,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

         auto blockwise_gemm =
-            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
                                                                 FloatAB,
                                                                 FloatGemmAcc,
                                                                 decltype(a_block_desc_ak0_m_ak1),
...
@@ -457,43 +445,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);

         // gridwise GEMM pipeline
-        const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
-                                    remove_cvref_t<decltype(a_block_desc_ak0_m_ak1)>,
-                                    remove_cvref_t<decltype(a_blockwise_copy)>,
-                                    remove_cvref_t<decltype(a_grid_buf)>,
-                                    remove_cvref_t<decltype(a_block_buf)>,
-                                    remove_cvref_t<decltype(a_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(b_grid_desc_bk0_n_bk1)>,
-                                    remove_cvref_t<decltype(b_block_desc_bk0_n_bk1)>,
-                                    remove_cvref_t<decltype(b_blockwise_copy)>,
-                                    remove_cvref_t<decltype(b_grid_buf)>,
-                                    remove_cvref_t<decltype(b_block_buf)>,
-                                    remove_cvref_t<decltype(b_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(blockwise_gemm)>,
-                                    remove_cvref_t<decltype(c_thread_buf)>,
-                                    NumGemmKPrefetchStage,
-                                    HasMainK0BlockLoop>{};
-
         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
             KPerBlock);

-        gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1,
+        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
                                                           a_block_desc_ak0_m_ak1,
                                                           a_blockwise_copy,
                                                           a_grid_buf,
                                                           a_block_buf,
                                                           a_block_slice_copy_step,
                                                           b_grid_desc_bk0_n_bk1,
                                                           b_block_desc_bk0_n_bk1,
                                                           b_blockwise_copy,
                                                           b_grid_buf,
                                                           b_block_buf,
                                                           b_block_slice_copy_step,
                                                           blockwise_gemm,
                                                           c_thread_buf,
                                                           num_k_block_main_loop);

         // shuffle C and write out
         {
...
@@ -609,8 +579,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             ck::tensor_operation::element_wise::PassThrough{}};

         // shuffle: blockwise copy C from LDS to global
         auto c_shuffle_block_copy_lds_to_global =
-            BlockwiseTensorSliceTransfer_v6r1<BlockSize,         // index_t BlockSize,
+            ThreadGroupTensorSliceTransfer_v6r1<ThisThreadBlock, // ThreadGroup
                                               CElementwiseOperation,      // ElementwiseOperation,
                                               CGlobalMemoryDataOperation, // DstInMemOp,
                                               Sequence<1,
...
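The recurring substitution in this file is `BlockwiseTensorSliceTransfer_*<BlockSize, ...>` becoming `ThreadGroupTensorSliceTransfer_*<ThisThreadBlock, ...>`: transfers are now parameterized by a thread-group type instead of a raw thread count, so the same transfer can later be driven by a sub-group of the block (as pipeline v2 needs). A compilable sketch of the idea; both types below are hypothetical minimal stand-ins, not the CK definitions:

#include <cassert>
#include <cstdint>

using index_t = int32_t;

// Hypothetical minimal thread group: a compile-time thread count plus a
// membership test. This is the interface shape the refactored transfers
// depend on, instead of a bare BlockSize integer.
template <index_t NThread>
struct ThisThreadBlock
{
    static constexpr index_t GetNumOfThread() { return NThread; }
    static constexpr bool IsBelong() { return true; } // every block thread participates
};

// A transfer templated on the group type sizes its per-thread work from
// GetNumOfThread() rather than from a separate BlockSize parameter.
template <typename ThreadGroup>
struct TensorSliceTransferSketch
{
    static constexpr index_t work_per_thread(index_t total)
    {
        return total / ThreadGroup::GetNumOfThread();
    }
};

int main()
{
    using TG = ThisThreadBlock<256>;
    assert(TensorSliceTransferSketch<TG>::work_per_thread(4096) == 16);
    return 0;
}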
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp  (View file @ 4a96c2e4)

...
@@ -22,7 +22,7 @@ template <typename GridwiseGemm,
           typename BGridDesc_BK0_N_BK1,
           typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename Block2CTileMap,
-          bool HasMainK0BlockLoop>
+          bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...
@@ -41,17 +41,17 @@ __global__ void
 {
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

-    GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                   p_b_grid,
                                                   p_c_grid,
                                                   p_shared,
                                                   a_element_op,
                                                   b_element_op,
                                                   c_element_op,
                                                   a_grid_desc_ak0_m_ak1,
                                                   b_grid_desc_bk0_n_bk1,
                                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                   block_2_ctile_map);
 }

 template <typename FloatAB,
...
@@ -115,7 +115,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
     static constexpr auto BK1 = Number<BK1Value>{};

     using ThisThreadBlock =
-        AnyThreadBlock<ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize>;
+        ThisThreadBlock<ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize>;

 #if 0
     struct ABBlockTransferThreadGroup
...
@@ -158,6 +158,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
     using CShuffleBlockTransferThreadGroup = ThisThreadBlock;
 #endif

+#if 1
+    // gridwise GEMM pipeline
+    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+#else
+    // gridwise GEMM pipeline
+    using GridwiseGemmPipe = GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
+                                                     BlockGemmThreadGroup,
+                                                     NumGemmKPrefetchStage>;
+#endif
+
     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
         // A matrix in LDS memory, dst of blockwise copy
...
@@ -223,10 +233,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
                   const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                   const CGridDesc_M_N& c_grid_desc_m_n)
     {
-        // static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
-        //                   is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
-        //               "wrong! K1 need to be known at compile-time");
-
         static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                           (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                       "Invalid tuning param!");
...
@@ -241,21 +247,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
             return false;

-        // check NumGemmKPrefetchStage
-        if constexpr(NumGemmKPrefetchStage == 1)
-        {
-            // 1-stage prefetch always supported
-        }
-        else if constexpr(NumGemmKPrefetchStage == 2)
-        {
-            // 2-stage prefetch currently only support even number of K0 loop
-            // TODO: add support for odd number of K0 loop
-            if(!((K / KPerBlock) % 2 == 0))
-            {
-                return false;
-            }
-        }
-        else
+        // check gridwise gemm pipeline
+        const auto num_k_loop = K / KPerBlock;
+
+        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
             return false;
         }
...
@@ -275,12 +270,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
         return grid_size;
     }

-    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
+    // TODO move this function into GEMM-pipeline class
+    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
     {
-        const bool has_main_k0_block_loop =
-            ((K0 * AK1) / (NumGemmKPrefetchStage * KPerBlock)) > 1;
-
-        return has_main_k0_block_loop;
+        const index_t num_loop = K / KPerBlock;
+
+        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }

     __host__ __device__ static constexpr auto
...
@@ -348,7 +342,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;

-    template <bool HasMainK0BlockLoop, typename Block2CTileMap>
+    template <bool HasMainKBlockLoop, typename Block2CTileMap>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
...
@@ -493,63 +487,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
             KPerBlock);

-#if 1
-        // gridwise GEMM pipeline
-        const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
-                                    remove_cvref_t<decltype(a_block_desc_ak0_m_ak1)>,
-                                    remove_cvref_t<decltype(a_blockwise_copy)>,
-                                    remove_cvref_t<decltype(a_grid_buf)>,
-                                    remove_cvref_t<decltype(a_block_buf)>,
-                                    remove_cvref_t<decltype(a_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(b_grid_desc_bk0_n_bk1)>,
-                                    remove_cvref_t<decltype(b_block_desc_bk0_n_bk1)>,
-                                    remove_cvref_t<decltype(b_blockwise_copy)>,
-                                    remove_cvref_t<decltype(b_grid_buf)>,
-                                    remove_cvref_t<decltype(b_block_buf)>,
-                                    remove_cvref_t<decltype(b_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(blockwise_gemm)>,
-                                    remove_cvref_t<decltype(c_thread_buf)>,
-                                    NumGemmKPrefetchStage,
-                                    HasMainK0BlockLoop>{};
-#else
-        // gridwise GEMM pipeline
-        const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
-                                    BlockGemmThreadGroup,
-                                    remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
-                                    remove_cvref_t<decltype(a_block_desc_ak0_m_ak1)>,
-                                    remove_cvref_t<decltype(a_blockwise_copy)>,
-                                    remove_cvref_t<decltype(a_grid_buf)>,
-                                    remove_cvref_t<decltype(a_block_buf)>,
-                                    remove_cvref_t<decltype(a_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(b_grid_desc_bk0_n_bk1)>,
-                                    remove_cvref_t<decltype(b_block_desc_bk0_n_bk1)>,
-                                    remove_cvref_t<decltype(b_blockwise_copy)>,
-                                    remove_cvref_t<decltype(b_grid_buf)>,
-                                    remove_cvref_t<decltype(b_block_buf)>,
-                                    remove_cvref_t<decltype(b_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(blockwise_gemm)>,
-                                    remove_cvref_t<decltype(c_thread_buf)>,
-                                    NumGemmKPrefetchStage,
-                                    HasMainK0BlockLoop>{};
-#endif
-
-        gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1,
-                                   a_block_desc_ak0_m_ak1,
-                                   a_blockwise_copy,
-                                   a_grid_buf,
-                                   a_block_buf,
-                                   a_block_slice_copy_step,
-                                   b_grid_desc_bk0_n_bk1,
-                                   b_block_desc_bk0_n_bk1,
-                                   b_blockwise_copy,
-                                   b_grid_buf,
-                                   b_block_buf,
-                                   b_block_slice_copy_step,
-                                   blockwise_gemm,
-                                   c_thread_buf,
-                                   num_k_block_main_loop);
+        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
+                                                          a_block_desc_ak0_m_ak1,
+                                                          a_blockwise_copy,
+                                                          a_grid_buf,
+                                                          a_block_buf,
+                                                          a_block_slice_copy_step,
+                                                          b_grid_desc_bk0_n_bk1,
+                                                          b_block_desc_bk0_n_bk1,
+                                                          b_blockwise_copy,
+                                                          b_grid_buf,
+                                                          b_block_buf,
+                                                          b_block_slice_copy_step,
+                                                          blockwise_gemm,
+                                                          c_thread_buf,
+                                                          num_k_block_main_loop);

         // shuffle C and write out
         {
...
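With both pipelines reduced to the same static interface, v2 now picks between them through a preprocessor switch on a single `using` alias, so the rest of the kernel is pipeline-agnostic. A sketch of the shape of that selection (stub types only, not the CK classes):

#include <cstdio>

// Stub pipelines standing in for GridwiseGemmPipeline_v1/_v2; only the shape
// of the compile-time selection is illustrated.
template <int NumPrefetch>
struct PipelineV1
{
    static constexpr const char* name() { return "single thread-group pipeline (v1)"; }
};

struct TransferGroup
{
};
struct GemmGroup
{
};

template <typename ABTransferGroup, typename BlockGemmGroup, int NumPrefetch>
struct PipelineV2
{
    static constexpr const char* name() { return "split thread-group pipeline (v2)"; }
};

// The commit hard-wires the choice with "#if 1"; the alias is the only name
// the rest of the kernel consumes, so flipping the switch swaps the pipeline
// everywhere at once.
#if 1
using GridwiseGemmPipe = PipelineV1<1>;
#else
using GridwiseGemmPipe = PipelineV2<TransferGroup, GemmGroup, 1>;
#endif

int main()
{
    std::printf("%s\n", GridwiseGemmPipe::name());
    return 0;
}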
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
4a96c2e4
#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP
#pragma once
#define CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP
#include "common_header.hpp"
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
...
@@ -22,7 +20,7 @@ template <typename GridwiseGemm,
...
@@ -22,7 +20,7 @@ template <typename GridwiseGemm,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
Block2CTileMap
,
typename
Block2CTileMap
,
bool
HasMainK
0
BlockLoop
>
bool
HasMainKBlockLoop
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
...
@@ -42,17 +40,17 @@ __global__ void
...
@@ -42,17 +40,17 @@ __global__ void
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

-    GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                   p_b_grid,
                                                   p_c_grid,
                                                   p_shared,
                                                   a_grid_desc_k0_m_k1,
                                                   b_grid_desc_k0_n_k1,
                                                   c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                                   a_element_op,
                                                   b_element_op,
                                                   c_element_op,
                                                   block_2_ctile_map);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
...
@@ -67,88 +65,6 @@ __global__ void
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
-template <typename GridwiseGemm,
-          typename FloatAB,
-          typename FloatC,
-          typename GemmDesc,
-          typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation,
-          bool HasMainK0BlockLoop,
-          index_t MaxGroupCount>
-__global__ void
-#if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
-#endif
-        kernel_grouped_gemm_xdlops_v2r3(
-            const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_desc_,
-            const index_t group_count,
-            const AElementwiseOperation a_element_op,
-            const BElementwiseOperation b_element_op,
-            const CElementwiseOperation c_element_op)
-{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
-    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
-
-    const index_t block_id = get_block_1d_id();
-
-#if 1
-    static_for<0, MaxGroupCount, 1>{}([&](auto i) {
-        if(block_id >= gemm_desc_[i].BlockStart_ && block_id < gemm_desc_[i].BlockEnd_ &&
-           i < group_count)
-        {
-            auto group_id = i;
-
-            GridwiseGemm::template Run<HasMainK0BlockLoop>(
-                gemm_desc_[group_id].a_ptr,
-                gemm_desc_[group_id].b_ptr,
-                gemm_desc_[group_id].c_ptr,
-                p_shared,
-                gemm_desc_[group_id].a_grid_desc_k0_m_k1_,
-                gemm_desc_[group_id].b_grid_desc_k0_n_k1_,
-                gemm_desc_[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                a_element_op,
-                b_element_op,
-                c_element_op,
-                gemm_desc_[group_id].grouped_gemm_block_2_ctile_map_);
-        }
-    });
-#else
-    const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_desc_);
-
-    index_t group_id = 0;
-
-    static_for<0, MaxGroupCount, 1>{}([&](auto i) {
-        group_id = (block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd &&
-                    i < group_count)
-                       ? i
-                       : group_id;
-    });
-
-    const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart;
-
-    GridwiseGemm::template Run<HasMainK0BlockLoop>(
-        gemm_desc_ptr[group_id].a_ptr,
-        gemm_desc_ptr[group_id].b_ptr,
-        gemm_desc_ptr[group_id].c_ptr,
-        p_shared,
-        gemm_desc_ptr[group_id].a_grid_desc_k0_m_k1_,
-        gemm_desc_ptr[group_id].b_grid_desc_k0_n_k1_,
-        gemm_desc_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-        a_element_op,
-        b_element_op,
-        c_element_op,
-        gemm_desc_ptr[group_id].block_2_ctile_map_,
-        block_id_grp);
-#endif
-#else
-    ignore = gemm_desc_;
-    ignore = group_count;
-    ignore = a_element_op;
-    ignore = b_element_op;
-    ignore = c_element_op;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
-}
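The block removed here is the grouped-GEMM entry point, which this commit moves into the new device_grouped_gemm_xdl.hpp (the +82 lines in the file list at the top of the commit). Its job is to map a flat block id onto contiguous per-group block ranges [BlockStart_, BlockEnd_). A host-side sketch of that mapping, with hypothetical names (the kernel does the scan with a compile-time static_for so the descriptor array stays statically indexed):

#include <cstddef>
#include <cstdint>
#include <vector>

struct GroupRange { std::int32_t block_start; std::int32_t block_end; };

// Prefix-sum the per-group tile counts into contiguous grid slices.
std::vector<GroupRange> make_ranges(const std::vector<std::int32_t>& tiles_per_group)
{
    std::vector<GroupRange> ranges;
    std::int32_t offset = 0;
    for(std::int32_t n : tiles_per_group)
    {
        ranges.push_back({offset, offset + n});
        offset += n;
    }
    return ranges;
}

// Linear scan over groups, as in the kernel's static_for over MaxGroupCount.
std::int32_t find_group(const std::vector<GroupRange>& ranges, std::int32_t block_id)
{
    for(std::size_t g = 0; g < ranges.size(); ++g)
        if(block_id >= ranges[g].block_start && block_id < ranges[g].block_end)
            return static_cast<std::int32_t>(g);
    return -1; // block_id falls outside every group
}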
 template <index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
...
@@ -187,7 +103,7 @@ template <index_t BlockSize,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector,
-          index_t NumPrefetch = 1>
+          index_t NumGemmKPrefetchStage = 1>
 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
 {
     static constexpr auto I0 = Number<0>{};
...
@@ -202,6 +118,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
     // K1 should be Number<...>
     static constexpr auto K1 = Number<K1Value>{};

+    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+
+    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
     __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
     {
         constexpr auto max_lds_align = K1;
...
@@ -291,21 +211,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
             return false;

-        // check NumPrefetch
-        if constexpr(NumPrefetch == 1)
-        {
-            // 1-stage prefetch always supported
-        }
-        else if constexpr(NumPrefetch == 2)
-        {
-            // 2-stage prefetch currently only support even number of K0 loop
-            // TODO: add support for odd number of K0 loop
-            if(!((K0 / K0PerBlock) % 2 == 0))
-            {
-                return false;
-            }
-        }
-        else
-        {
-            return false;
-        }
+        // check gridwise gemm pipeline
+        const auto num_k_loop = K0 / K0PerBlock;
+
+        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
+        {
+            return false;
+        }
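Here the per-pipeline support rules (1-stage: any K-loop count; 2-stage: only an even count) move out of the GEMM and behind GridwiseGemmPipe::IsSupported, so adding a pipeline no longer means editing every gridwise GEMM. A minimal sketch of such a hook, mirroring the old inline rules (the real GridwiseGemmPipeline_v1 may check more):

template <int NumPrefetchStage>
struct PipelineSketch
{
    static constexpr bool IsSupported(int num_k_loop)
    {
        if constexpr(NumPrefetchStage == 1)
            return true;                // 1-stage prefetch: any loop count
        else if constexpr(NumPrefetchStage == 2)
            return num_k_loop % 2 == 0; // 2-stage: even loop count only
        else
            return false;               // other stage counts unsupported
    }
};

static_assert(PipelineSketch<1>::IsSupported(3));
static_assert(!PipelineSketch<2>::IsSupported(3));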
...
@@ -335,12 +244,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         return grid_size;
     }
-    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
-    {
-        const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1;
-
-        return has_main_k0_block_loop;
-    }
+    // TODO move this function into GEMM-pipeline class
+    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
+    {
+        const index_t num_loop = K / (K0PerBlock * K1);
+
+        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
+    }
     __host__ __device__ static constexpr auto
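The new predicate counts K-tiles the same way the pipeline does: one main-loop iteration consumes K0PerBlock * K1 elements of K, so num_loop = K / (K0PerBlock * K1). For instance, with K = 4096, K1 = 8, and K0PerBlock = 4, each iteration covers 32 elements and num_loop = 128. A minimal worked version of the arithmetic, assuming (as the old code did for one prefetch stage) that a main loop exists whenever num_loop > 1:

constexpr int K          = 4096;
constexpr int K1         = 8;
constexpr int K0PerBlock = 4;

constexpr int num_loop = K / (K0PerBlock * K1); // 4096 / 32 = 128
static_assert(num_loop == 128);

// Assumed 1-stage rule; GridwiseGemmPipeline_v1 owns the real decision.
constexpr bool has_main_loop = num_loop > 1;
static_assert(has_main_loop);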
...
@@ -379,7 +287,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         }();

         using BlockwiseGemm =
-            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
                                                                 FloatAB,
                                                                 FloatAcc,
                                                                 decltype(a_block_desc_k0_m_k1),
...
@@ -433,7 +341,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));

     using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));

-    template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
+    template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
     __device__ static void
         Run(const FloatAB* __restrict__ p_a_grid,
             const FloatAB* __restrict__ p_b_grid,
...
@@ -499,7 +407,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                                                 1,
                                                 AThreadTransferSrcResetCoordinateAfterRun,
                                                 true,
-                                                NumPrefetch>(
+                                                NumGemmKPrefetchStage>(
                 a_grid_desc_k0_m_k1,
                 make_multi_index(0, m_block_data_idx_on_grid, 0),
                 a_element_op,
...
@@ -530,7 +438,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                                                 1,
                                                 BThreadTransferSrcResetCoordinateAfterRun,
                                                 true,
-                                                NumPrefetch>(
+                                                NumGemmKPrefetchStage>(
                 b_grid_desc_k0_n_k1,
                 make_multi_index(0, n_block_data_idx_on_grid, 0),
                 b_element_op,
...
@@ -547,7 +455,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         // sanity check
         auto blockwise_gemm =
-            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
                                                                 FloatAB,
                                                                 FloatAcc,
                                                                 decltype(a_block_desc_k0_m_k1),
...
@@ -575,41 +483,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

         // gridwise GEMM pipeline
-        const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_k0_m_k1)>,
-                                    remove_cvref_t<decltype(a_block_desc_k0_m_k1)>,
-                                    remove_cvref_t<decltype(a_blockwise_copy)>,
-                                    remove_cvref_t<decltype(a_grid_buf)>,
-                                    remove_cvref_t<decltype(a_block_buf)>,
-                                    remove_cvref_t<decltype(a_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(b_grid_desc_k0_n_k1)>,
-                                    remove_cvref_t<decltype(b_block_desc_k0_n_k1)>,
-                                    remove_cvref_t<decltype(b_blockwise_copy)>,
-                                    remove_cvref_t<decltype(b_grid_buf)>,
-                                    remove_cvref_t<decltype(b_block_buf)>,
-                                    remove_cvref_t<decltype(b_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(blockwise_gemm)>,
-                                    remove_cvref_t<decltype(c_thread_buf)>,
-                                    NumPrefetch,
-                                    HasMainK0BlockLoop>{};
-
-        const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
-
-        gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1,
-                                   a_block_desc_k0_m_k1,
-                                   a_blockwise_copy,
-                                   a_grid_buf,
-                                   a_block_buf,
-                                   a_block_slice_copy_step,
-                                   b_grid_desc_k0_n_k1,
-                                   b_block_desc_k0_n_k1,
-                                   b_blockwise_copy,
-                                   b_grid_buf,
-                                   b_block_buf,
-                                   b_block_slice_copy_step,
-                                   blockwise_gemm,
-                                   c_thread_buf,
-                                   K0BlockMainLoop);
+        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
+
+        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
+                                                          a_block_desc_k0_m_k1,
+                                                          a_blockwise_copy,
+                                                          a_grid_buf,
+                                                          a_block_buf,
+                                                          a_block_slice_copy_step,
+                                                          b_grid_desc_k0_n_k1,
+                                                          b_block_desc_k0_n_k1,
+                                                          b_blockwise_copy,
+                                                          b_grid_buf,
+                                                          b_block_buf,
+                                                          b_block_slice_copy_step,
+                                                          blockwise_gemm,
+                                                          c_thread_buf,
+                                                          num_k_block_main_loop);

         // output: register to global memory
         {
...
@@ -692,4 +582,3 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
 };

 } // namespace ck
-#endif
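One detail worth noting in the rewritten main loop above: the trip count is wrapped in __builtin_amdgcn_readfirstlane. K0 / K0PerBlock is the same in every lane of the wavefront, but the compiler cannot always prove that; broadcasting lane 0's value tells it the count is wave-uniform, so loop control can live in scalar registers. A minimal sketch of the idiom (the builtin is provided by AMD's HIP/ROCm compiler; the function name here is illustrative):

__device__ int uniform_trip_count(int K0, int K0PerBlock)
{
    // Broadcast lane 0's result so the value is provably wave-uniform.
    return __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
}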
include/ck/utility/thread_group.hpp
...
@@ -4,7 +4,7 @@
 namespace ck {

 template <index_t ThreadPerBlock>
-struct AnyThreadBlock
+struct ThisThreadBlock
 {
     static constexpr index_t kNumThread_ = ThreadPerBlock;
...
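This one-line rename is what the BlockwiseGemmXdlops_* changes above build on: algorithms are now parameterized by a thread-group type (ThisThreadBlock<BlockSize>) rather than a bare BlockSize integer. A sketch of the concept, where everything beyond kNumThread_ is assumed for illustration and not taken from this diff:

template <int ThreadPerBlock>
struct ThisThreadBlockSketch
{
    static constexpr int kNumThread_ = ThreadPerBlock;

    __device__ static constexpr int GetNumOfThread() { return kNumThread_; }
    __device__ static int GetThreadId() { return threadIdx.x; }
};

// An algorithm written against the group, not a raw integer:
template <typename ThreadGroup>
__device__ void zero_fill(float* p, int n)
{
    for(int i = ThreadGroup::GetThreadId(); i < n; i += ThreadGroup::GetNumOfThread())
        p[i] = 0.0f; // group members stride over the buffer together
}

Passing the group as a type keeps the thread count compile-time while letting a different group (say, a single wavefront) slot in later without touching the algorithm.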