Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d656321b
Commit
d656321b
authored
Jun 14, 2022
by
wangshaojie6
Browse files
add 4 stage one
parent
575a50dd
Changes
7
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
1650 additions
and
175 deletions
+1650
-175
example/01_gemm/gemm_xdl_fp16.cpp
example/01_gemm/gemm_xdl_fp16.cpp
+12
-0
include/ck/config.hpp
include/ck/config.hpp
+1
-1
include/ck/tensor_operation/gpu/device/device_gemm_xdl_producer_consumer_cshuffle.hpp
...gpu/device/device_gemm_xdl_producer_consumer_cshuffle.hpp
+656
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_producer_consumer.hpp
...ion/gpu/grid/gridwise_gemm_pipeline_producer_consumer.hpp
+234
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
+36
-173
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_producer_consumer_cshuffle.hpp
...gpu/grid/gridwise_gemm_xdl_producer_consumer_cshuffle.hpp
+710
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
...tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
+1
-1
No files found.
example/01_gemm/gemm_xdl_fp16.cpp
View file @
d656321b
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#include "element_wise_operation.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
#include "gemm_specialization.hpp"
#include "device_gemm_xdl_producer_consumer_cshuffle.hpp"
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
@@ -43,12 +44,23 @@ using CElementOp = ck::tensor_operation::element_wise::PassThrough;
...
@@ -43,12 +44,23 @@ using CElementOp = ck::tensor_operation::element_wise::PassThrough;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// clang-format off
// clang-format off
#if 0
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>;
< Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>;
#else
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemm_Xdl_ProducerConsumer_CShuffle
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
1
,
256
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
64
,
1
,
8
>
,
8
>
;
#endif
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
...
...
include/ck/config.hpp
View file @
d656321b
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
#define CK_USE_LAUNCH_BOUNDS 1
#define CK_USE_LAUNCH_BOUNDS 1
#ifdef CK_USE_LAUNCH_BOUNDS
#ifdef CK_USE_LAUNCH_BOUNDS
#define CK_MAX_THREAD_PER_BLOCK
256
#define CK_MAX_THREAD_PER_BLOCK
512
#define CK_MIN_BLOCK_PER_CU 1
#define CK_MIN_BLOCK_PER_CU 1
#endif
#endif
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_producer_consumer_cshuffle.hpp
0 → 100644
View file @
d656321b
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_producer_consumer.hpp
0 → 100644
View file @
d656321b
#pragma once

#include "common_header.hpp"

namespace ck {

// Producer/consumer gridwise GEMM pipeline.
//
// Unlike the usual pipelines where every thread both copies and computes, this
// pipeline splits the thread block into two cooperating groups:
//   - ABBlockTransferThreadGroup ("producer"): streams A/B tiles from global
//     memory into LDS,
//   - BlockGemmThreadGroup ("consumer"): runs the blockwise GEMM out of LDS.
// The primary template is declared only; each NumGemmKPrefetchStage value gets
// its own specialization.
template <typename ABBlockTransferThreadGroup,
          typename BlockGemmThreadGroup,
          index_t NumGemmKPrefetchStage>
struct GridwiseGemmPipelineProducerConsumer;

// 1-stage prefetch
template <typename ABBlockTransferThreadGroup, typename BlockGemmThreadGroup>
struct GridwiseGemmPipelineProducerConsumer<ABBlockTransferThreadGroup, BlockGemmThreadGroup, 1>
{
    // The prologue/epilogue structure below consumes loop iterations in pairs,
    // so the K-loop trip count must be even.
    __host__ __device__ static constexpr bool IsSupported(index_t num_loop)
    {
        // TODO: improve applicability
        return num_loop % 2 == 0;
    }

    // The main do/while body is entered only when more than one iteration pair
    // remains after the prologue/epilogue (i.e. num_loop >= 4).
    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return num_loop / 2 > 1;
    }

    // Producer side: global->LDS copy pipeline for both the A and B tiles.
    //
    // NOTE(review): this function mirrors RunBlockGemmPipeline stage-for-stage;
    // the block_sync_lds() calls here must stay in lockstep with the ones
    // there, since the two functions are executed by different thread groups
    // of the same block (see Run below). The "GEMM i" comments mark where the
    // consumer group computes while this group only synchronizes.
    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep>
    static __device__ void RunABBlockTransferPipeline(const AGridDesc& a_grid_desc,
                                                      const ABlockDesc& a_block_desc,
                                                      ABlockTransfer& a_block_copy,
                                                      const AGridBuffer& a_grid_buf,
                                                      ABlockBuffer& a_block_buf,
                                                      const ABlockTransferStep& a_block_copy_step,
                                                      const BGridDesc& b_grid_desc,
                                                      const BBlockDesc& b_block_desc,
                                                      BBlockTransfer& b_block_copy,
                                                      const BGridBuffer& b_grid_buf,
                                                      BBlockBuffer& b_block_buf,
                                                      const BBlockTransferStep& b_block_copy_step,
                                                      index_t num_loop)
    {
        // Prologue: prefetch the first two K tiles so the consumer always has
        // data to work on while the next global read is in flight.
        // global read 0
        a_block_copy.RunRead(a_grid_desc, a_grid_buf);
        b_block_copy.RunRead(b_grid_desc, b_grid_buf);

        // move to 1
        a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // LDS write 0
        a_block_copy.RunWrite(a_block_desc, a_block_buf);
        // global Read 1
        a_block_copy.RunRead(a_grid_desc, a_grid_buf);

        // LDS write 0
        b_block_copy.RunWrite(b_block_desc, b_block_buf);
        // global Read 1
        b_block_copy.RunRead(b_grid_desc, b_grid_buf);

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;

            do
            {
                block_sync_lds();

                // GEMM i (performed by the consumer thread group)
                block_sync_lds();

                // move to i + 2
                a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                // LDS write i + 1
                a_block_copy.RunWrite(a_block_desc, a_block_buf);
                // global read i + 2
                a_block_copy.RunRead(a_grid_desc, a_grid_buf);

                // LDS write i + 1
                b_block_copy.RunWrite(b_block_desc, b_block_buf);
                // global read i + 2
                b_block_copy.RunRead(b_grid_desc, b_grid_buf);

                ++i;
            } while(i < (num_loop - 2));
        }

        // tail: flush the last prefetched tile to LDS; the final two GEMMs run
        // on the consumer side between the matching barriers.
        {
            block_sync_lds();

            // GEMM num_loop - 2 (consumer side)
            block_sync_lds();

            // LDS write num_loop - 1
            a_block_copy.RunWrite(a_block_desc, a_block_buf);
            b_block_copy.RunWrite(b_block_desc, b_block_buf);

            block_sync_lds();

            // GEMM num_loop - 1 (consumer side)
        }
    }

    // Consumer side: blockwise GEMM pipeline reading tiles from LDS.
    //
    // NOTE(review): the comment-only lines inside the loop mark the stages
    // that the producer group executes at the same point of its own pipeline;
    // the barrier pattern must match RunABBlockTransferPipeline exactly.
    template <bool HasMainLoop,
              typename ABlockBuffer,
              typename BBlockBuffer,
              typename BlockwiseGemm,
              typename CThreadBuffer>
    static __device__ void RunBlockGemmPipeline(ABlockBuffer& a_block_buf,
                                                BBlockBuffer& b_block_buf,
                                                const BlockwiseGemm& block_gemm,
                                                CThreadBuffer& c_thread_buf,
                                                index_t num_loop)
    {
        // Initialize C
        c_thread_buf.Clear();

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;

            do
            {
                block_sync_lds();

                // GEMM i
                block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

                block_sync_lds();

                // Producer group concurrently does:
                // move to i + 2
                // LDS write i + 1
                // global read i + 2
                // LDS write i + 1
                // global read i + 2

                ++i;
            } while(i < (num_loop - 2));
        }

        // tail: the last two GEMMs, separated by the producer's final LDS
        // write of tile num_loop - 1.
        {
            block_sync_lds();

            // GEMM num_loop - 2
            block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

            block_sync_lds();

            // LDS write num_loop - 1 (producer side)
            block_sync_lds();

            // GEMM num_loop - 1
            block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        }
    }

    // Entry point: dispatches each thread to the producer or consumer pipeline
    // based on which thread group it belongs to. Threads in neither group do
    // nothing; both branches execute the same number of block_sync_lds()
    // barriers so the block stays convergent on synchronization.
    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename BlockwiseGemm,
              typename CThreadBuffer>
    static __device__ void Run(const AGridDesc& a_grid_desc,
                               const ABlockDesc& a_block_desc,
                               ABlockTransfer& a_block_copy,
                               const AGridBuffer& a_grid_buf,
                               ABlockBuffer& a_block_buf,
                               const ABlockTransferStep& a_block_copy_step,
                               const BGridDesc& b_grid_desc,
                               const BBlockDesc& b_block_desc,
                               BBlockTransfer& b_block_copy,
                               const BGridBuffer& b_grid_buf,
                               BBlockBuffer& b_block_buf,
                               const BBlockTransferStep& b_block_copy_step,
                               const BlockwiseGemm& block_gemm,
                               CThreadBuffer& c_thread_buf,
                               index_t num_loop)
    {
        if(ABBlockTransferThreadGroup::IsBelong())
        {
            RunABBlockTransferPipeline<HasMainLoop>(a_grid_desc,
                                                    a_block_desc,
                                                    a_block_copy,
                                                    a_grid_buf,
                                                    a_block_buf,
                                                    a_block_copy_step,
                                                    b_grid_desc,
                                                    b_block_desc,
                                                    b_block_copy,
                                                    b_grid_buf,
                                                    b_block_buf,
                                                    b_block_copy_step,
                                                    num_loop);
        }
        else if(BlockGemmThreadGroup::IsBelong())
        {
            RunBlockGemmPipeline<HasMainLoop>(
                a_block_buf, b_block_buf, block_gemm, c_thread_buf, num_loop);
        }
    }
};

} // namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
View file @
d656321b
...
@@ -459,12 +459,12 @@ struct GridwiseGemmPipeline_v2<4>
...
@@ -459,12 +459,12 @@ struct GridwiseGemmPipeline_v2<4>
__host__
__device__
static
constexpr
bool
IsSupported
(
index_t
num_loop
)
__host__
__device__
static
constexpr
bool
IsSupported
(
index_t
num_loop
)
{
{
// TODO: improve applicability
// TODO: improve applicability
return
num_loop
>
4
;
return
num_loop
%
4
==
0
;
}
}
__host__
__device__
static
constexpr
bool
CalculateHasMainLoop
(
index_t
num_loop
)
__host__
__device__
static
constexpr
bool
CalculateHasMainLoop
(
index_t
num_loop
)
{
{
return
num_loop
>
4
;
return
num_loop
/
4
>
1
;
}
}
template
<
bool
HasMainLoop
,
template
<
bool
HasMainLoop
,
...
@@ -498,34 +498,15 @@ struct GridwiseGemmPipeline_v2<4>
...
@@ -498,34 +498,15 @@ struct GridwiseGemmPipeline_v2<4>
CThreadBuffer
&
c_thread_buf
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
index_t
num_loop
)
{
{
// global read 0
static_for
<
0
,
4
,
1
>
{}([
&
](
auto
i_pre
){
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I0
);
// global read i_pre
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I0
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
Number
<
i_pre
>
{});
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
Number
<
i_pre
>
{});
// move to 1
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// global read 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I1
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I1
);
// move to 2
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// global read 2
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I2
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I2
);
// move to 3
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// global read 3
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I3
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I3
);
// move to i_pre + 1
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
});
// Initialize C
// Initialize C
c_thread_buf
.
Clear
();
c_thread_buf
.
Clear
();
...
@@ -536,167 +517,49 @@ struct GridwiseGemmPipeline_v2<4>
...
@@ -536,167 +517,49 @@ struct GridwiseGemmPipeline_v2<4>
{
{
do
do
{
{
// move to i + 4
static_for
<
0
,
4
,
1
>
{}([
&
](
auto
i_main
){
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// LDS write i
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
I0
);
// global Read i + 4
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I0
);
// LDS write i
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
I0
);
// global Read i + 4
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I0
);
block_sync_lds
();
// GEMM i
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
// LDS write i_main
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
Number
<
i_main
>
{});
// move to i + 5
// global Read i_main + 3
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
Number
<
i_main
>
{});
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// LDS write i + 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
I1
);
// global read i + 5
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I1
);
// LDS write i + 1
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
I1
);
// global read i + 5
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I1
);
block_sync_lds
();
// GEMM i + 1
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
// move to i + 6
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// LDS write i
+ 2
// LDS write i
_main
a
_blockwise_copy
.
RunWrite
(
a
_block_desc
,
a
_block_buf
,
I2
);
b
_blockwise_copy
.
RunWrite
(
b
_block_desc
,
b
_block_buf
,
Number
<
i_main
>
{}
);
// global
r
ead i +
6
// global
R
ead i
_main
+
3
a
_blockwise_copy
.
RunRead
(
a
_grid_desc
,
a
_grid_buf
,
I2
);
b
_blockwise_copy
.
RunRead
(
b
_grid_desc
,
b
_grid_buf
,
Number
<
i_main
>
{}
);
// LDS write i + 2
// move to i_main + 3
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
I2
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
// global read i + 6
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I2
);
block_sync_lds
();
block_sync_lds
();
// GEMM i
+ 2
// GEMM i
_main
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
block_sync_lds
();
});
// move to i + 7
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// LDS write i + 3
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
I3
);
// global read i + 7
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I3
);
// LDS write i + 3
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
I3
);
// global read i + 7
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I3
);
block_sync_lds
();
// GEMM i + 3
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
i
+=
4
;
i
+=
4
;
}
while
(
i
<
(
num_loop
-
4
));
}
while
(
i
<
(
num_loop
-
4
));
}
}
// tail
// tail
if
(
i
==
num_loop
-
4
)
static_for
<
0
,
I4
,
1
>
{}([
&
](
auto
i_res
){
{
static_for
<
0
,
I4
,
1
>
{}([
&
](
auto
i_res
){
// Write num_loop - 3
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
Number
<
i_res
>
{});
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
Number
<
i_res
>
{});
block_sync_lds
();
// Write num_loop - 3
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
Number
<
i_res
>
{});
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
Number
<
i_res
>
{});
// GEMM num_loop - 3
block_sync_lds
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
});
}
// tail
if
(
i
==
num_loop
-
3
)
{
static_for
<
0
,
I3
,
1
>
{}([
&
](
auto
i_res
){
// Write num_loop - 3
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
Number
<
i_res
>
{});
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
Number
<
i_res
>
{});
block_sync_lds
();
// GEMM num_loop - 3
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
});
}
// tail
else
if
(
i
==
num_loop
-
2
)
{
static_for
<
0
,
I2
,
1
>
{}([
&
](
auto
i_res
){
// Write num_loop
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
Number
<
i_res
>
{});
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
Number
<
i_res
>
{});
block_sync_lds
();
// GEMM num_loop
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
});
}
// tail
else
if
(
i
==
num_loop
-
1
)
{
static_for
<
0
,
I1
,
1
>
{}([
&
](
auto
i_res
){
// Write num_loop
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
Number
<
i_res
>
{});
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
Number
<
i_res
>
{});
block_sync_lds
();
// GEMM num_loop
// GEMM num_loop
- 3
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
block_sync_lds
();
});
});
}
}
}
};
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_producer_consumer_cshuffle.hpp
0 → 100644
View file @
d656321b
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
View file @
d656321b
...
@@ -233,7 +233,7 @@ template <index_t BlockSize,
...
@@ -233,7 +233,7 @@ template <index_t BlockSize,
index_t
CBlockTransferScalarPerVector_NWaveNPerXDL
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXDL
,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
bool
BBlockLdsExtraN1
=
false
,
bool
BBlockLdsExtraN1
=
false
,
index_t
NumGemmKPrefetchStage
=
3
>
index_t
NumGemmKPrefetchStage
=
4
>
struct
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
struct
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment