gaoqiong / composable_kernel_ROCM / Commits / e889d086

Commit e889d086, authored Feb 14, 2025 by feifei14119
parent e076a320

    save 51

Showing 8 changed files with 728 additions and 90 deletions (+728 -90)
example/ck_tile/18_flatmm/flatmm_basic.cpp  (+124 -0)
example/ck_tile/18_flatmm/flatmm_basic.hpp  (+2 -2)
example/ck_tile/18_flatmm/run_flatmm_example.inc  (+2 -2)
include/ck_tile/ops/flatmm.hpp  (+4 -4)
include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp  (+30 -10)
include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp  (+193 -22)
include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp  (+259 -38)
include/ck_tile/ops/flatmm/pipeline/flatmm_universal_pipeline_ag_bg_cr_policy.hpp  (+114 -12)
example/ck_tile/18_flatmm/flatmm_basic.cpp  View file @ e889d086

@@ -12,6 +12,7 @@
#include "ck_tile/host.hpp"
#include "flatmm_basic.hpp"

+#if 1
template <typename ALayout, typename BLayout, typename CLayout>
float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s)
{
@@ -117,6 +118,129 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
    return ave_time;
}
+#else
+template <typename ALayout, typename BLayout, typename CLayout>
+float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s)
+{
+    // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
+    constexpr bool kPadM = false;
+    constexpr bool kPadN = false;
+    constexpr bool kPadK = false;
+    constexpr int kBlockPerCu = 1;
+
+    // This part comes from the Codegen
+    /*constexpr ck_tile::index_t M_Tile = 128;
+    constexpr ck_tile::index_t N_Tile = 128;
+    constexpr ck_tile::index_t K_Tile = 32;
+    constexpr ck_tile::index_t M_Warp = 1;
+    constexpr ck_tile::index_t N_Warp = 4;
+    constexpr ck_tile::index_t K_Warp = 1;
+    constexpr ck_tile::index_t M_Warp_Tile = 32;
+    constexpr ck_tile::index_t N_Warp_Tile = 32;
+    constexpr ck_tile::index_t K_Warp_Tile = 8;*/
+    constexpr ck_tile::index_t M_Tile = 128;
+    constexpr ck_tile::index_t N_Tile = 128;
+    constexpr ck_tile::index_t K_Tile = 64;
+    constexpr ck_tile::index_t M_Warp = 2;
+    constexpr ck_tile::index_t N_Warp = 2;
+    constexpr ck_tile::index_t K_Warp = 1;
+    constexpr ck_tile::index_t M_Warp_Tile = 32;
+    constexpr ck_tile::index_t N_Warp_Tile = 32;
+    constexpr ck_tile::index_t K_Warp_Tile = 16;
+
+    using CodegenGemmShape =
+        ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
+                               ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
+                               ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
+
+    using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
+
+    using CodegenGemmTraits =
+        ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
+
+    using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<ADataType,
+                                                                BDataType,
+                                                                AccDataType,
+                                                                CodegenGemmShape,
+                                                                CodegenGemmTraits>;
+
+    using GemmEpilogue = ck_tile::CShuffleEpilogue<
+        ck_tile::CShuffleEpilogueProblem<AccDataType,
+                                         CDataType,
+                                         CLayout,
+                                         CodegenPipelineProblem::kBlockSize,
+                                         TilePartitioner::MPerBlock,
+                                         TilePartitioner::NPerBlock,
+                                         M_Warp,
+                                         N_Warp,
+                                         M_Warp_Tile,
+                                         N_Warp_Tile,
+                                         K_Warp_Tile,
+                                         CodegenPipelineProblem::TransposeC>>;
+
+    using CodegenFlatmmPolicy = ck_tile::UniversalFlatmmPipelineAgBgCrPolicy;
+    using CodegenFlatmmPipeline =
+        ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenFlatmmPolicy>;
+
+    // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
+    // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
+    using Kernel = ck_tile::FlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;
+
+    auto kargs = Kernel::MakeKernelArgs(args);
+
+    const dim3 grids      = Kernel::GridSize(args.M, args.N, args.k_batch);
+    constexpr dim3 blocks = Kernel::BlockSize();
+
+#if FEIFEI_DEBUG
+    /*using BlockFlatmmStruct = ck_tile::remove_cvref_t<decltype(CodegenFlatmmPolicy::template GetBlockFlatmm<CodegenPipelineProblem>())>;
+    auto block_flatmm = BlockFlatmmStruct(); // struct BlockFlatmmASmemBSmemCRegV1
+    //auto ADramTileDistr = CodegenFlatmmPolicy::template MakeADramTileDistribution<CodegenPipelineProblem>();
+    auto kernel = Kernel{};
+    using SplitKBatchOffset = typename Kernel::SplitKBatchOffset;
+    SplitKBatchOffset splitk_batch_offset(args);
+    auto gemm_tensor_views_tuple = Kernel::template MakeGemmTensorViews<ck_tile::memory_operation_enum::set>(
+        args.a_ptr,
+        args.b_shuffle_ptr,
+        args.c_ptr,
+        kargs, splitk_batch_offset);*/
+    printf("[FEIFEI] --- flatmm_calc() ---\n");
+    printf("[FEIFEI] BlockPerCu = %d\n", static_cast<int>(kBlockPerCu));
+    printf("[FEIFEI] BlockTile M = %d\n", static_cast<int>(M_Tile));
+    printf("[FEIFEI] BlockTile N = %d\n", static_cast<int>(N_Tile));
+    printf("[FEIFEI] BlockTile K = %d\n", static_cast<int>(K_Tile));
+    printf("[FEIFEI] WavePerBlock M = %d\n", static_cast<int>(M_Warp));
+    printf("[FEIFEI] WavePerBlock N = %d\n", static_cast<int>(N_Warp));
+    printf("[FEIFEI] WavePerBlock K = %d\n", static_cast<int>(K_Warp));
+    printf("[FEIFEI] WaveTile M = %d\n", static_cast<int>(M_Warp_Tile));
+    printf("[FEIFEI] WaveTile N = %d\n", static_cast<int>(N_Warp_Tile));
+    printf("[FEIFEI] WaveTile K = %d\n", static_cast<int>(K_Warp_Tile));
+    printf("[FEIFEI] grids = [%d, %d, %d]\n", grids.x, grids.y, grids.z);
+    printf("[FEIFEI] blocks = [%d, %d, %d]\n", blocks.x, blocks.y, blocks.z);
+#endif
+
+    if(!Kernel::IsSupportedArgument(kargs))
+    {
+        throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
+    }
+
+    if(s.log_level_ > 0)
+    {
+        std::cout << "Launching kernel with args:"
+                  << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
+                  << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
+                  << std::endl;
+    }
+
+    float ave_time = ck_tile::launch_kernel(
+        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
+
+    return ave_time;
+}
+#endif

#include "run_flatmm_example.inc"
...
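Note: the debug printfs above report the launch grid that Kernel::GridSize() returns for the 128x128 block tile. As a rough orientation only, here is a host-side sketch of how such a grid size typically comes about; the problem sizes and the assumption that the 1D partitioner flattens the (m, n) tile grid into grid.x are mine, not taken from this commit.

#include <cstdio>

// Hypothetical restatement of the grid-size arithmetic for an assumed problem size.
int main()
{
    const int M = 256, N = 128, k_batch = 1;       // assumed problem sizes
    const int M_Tile = 128, N_Tile = 128;          // block tile from flatmm_basic.cpp
    const int tiles_m = (M + M_Tile - 1) / M_Tile; // 2
    const int tiles_n = (N + N_Tile - 1) / N_Tile; // 1
    // Assumption: the 1D tile partitioner flattens the (m, n) tile grid into grid.x.
    std::printf("grids  = [%d, 1, %d]\n", tiles_m * tiles_n, k_batch);
    std::printf("blocks = [%d, 1, 1]\n", 256);     // assumed 256-thread workgroup
    return 0;
}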
example/ck_tile/18_flatmm/flatmm_basic.hpp  View file @ e889d086

@@ -80,12 +80,12 @@ auto create_args(int argc, char* argv[])
        .insert("n", "128", "n dimension") // 128, 4096
        .insert("k", "64", "k dimension")  // 64, 2048
        .insert("a_layout", "R", "A tensor data layout - Row by default")
-        .insert("b_layout", "R", "B tensor data layout - Row by default")
+        .insert("b_layout", "C", "B tensor data layout - Row by default")
        .insert("c_layout", "R", "C tensor data layout - Row by default")
        .insert("stride_a", "0", "Tensor A stride")
        .insert("stride_b", "0", "Tensor B stride")
        .insert("stride_c", "0", "Tensor C stride")
-        .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
+        .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
        .insert("warmup", "50", "number of iterations before benchmark the kernel")
        .insert("repeat", "100", "number of iterations to benchmark the kernel")
...
example/ck_tile/18_flatmm/run_flatmm_example.inc  View file @ e889d086

@@ -415,8 +415,8 @@ int run_flatmm_example_with_layouts(int argc,
    // b_shuffle
    {
        std::ofstream file("ff_b_shuffle_host.txt");
-        int X = static_cast<int>(K);
-        int Y = static_cast<int>(N);
+        int X = 32 * 32;
+        int Y = static_cast<int>(N) * static_cast<int>(M) / X;
        file << " [b_shuffle_host]: Row = " << Y << ", Col = " << X << std::endl;
        for(int y = 0; y < Y; y++)
...
include/ck_tile/ops/flatmm.hpp  View file @ e889d086

@@ -24,10 +24,10 @@
// kernel
#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"
-#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
-#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
-#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
-#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
+// #include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
+// #include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
+// #include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
+// #include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp  View file @ e889d086

@@ -55,14 +55,13 @@ struct BlockFlatmmASmemBSmemCRegV1
        return c_block_tensor;
    }

-#if 1
+#if 0
    // C += A * B
    // template <typename CBlockTensor, typename ABlockWindow, typename BBlockWindow>
-    template <typename ABlockWindow>
-    CK_TILE_DEVICE void operator()(const ABlockWindow& a_block_window
+    template <typename ABlockWindow, typename BBlockWindow>
+    CK_TILE_DEVICE void operator()(const ABlockWindow& a_block_window,
+                                   const BBlockWindow& b_block_window
#if FEIFEI_DEBUG
                                   ,
-                                   const BDataType* b_ptr,
                                   int* dbg_int,
                                   float* dbg_fp32,
                                   void* dbg_f168
...
@@ -101,14 +100,12 @@ struct BlockFlatmmASmemBSmemCRegV1
        */
        constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}];
-        // constexpr index_t NPerBlock = BBlockWindow{}.get_window_lengths()[number<0>{}];
+        constexpr index_t NPerBlock = BBlockWindow{}.get_window_lengths()[number<0>{}];
        constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}];

-        /*
        static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
                          KPerBlock == BlockGemmShape::kK,
                      "wrong!");
-        */

        constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
        using WG = remove_cvref_t<decltype(config.template at<0>())>;
...
@@ -117,11 +114,11 @@ struct BlockFlatmmASmemBSmemCRegV1
        constexpr index_t NWarp = config.template at<2>();

        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
-        // constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
+        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
        constexpr index_t KIterPerWarp = KPerBlock / WG::kK;

        constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
-        // constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
+        constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
        constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;

        const index_t iMWarp = get_warp_id() / NWarp;
...
@@ -133,6 +130,7 @@ struct BlockFlatmmASmemBSmemCRegV1
            make_tuple(number<WG::kM>{}, number<WG::kK>{}),
            a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
            make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));

        statically_indexed_array<
            statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
...
@@ -151,7 +149,29 @@ struct BlockFlatmmASmemBSmemCRegV1
        // Warp loop in block:
        constexpr index_t kIter = 0;
        constexpr index_t mIter = 0;
-        const auto a_warp_tensor = load_tile(a_warp_windows(number<mIter>{})(number<kIter>{}));
+        const auto a_warp_tensor = load_tile(a_warp_window_tmp);
+#if FEIFEI_DEBUG
+        if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
+        {
+            printf("[BLOCK ] WG::kM = %d, WG::kM = %d, WG::kK = %d, WG::kKPerThread = %d\n", WG::kM, WG::kN, WG::kK, WG::kKPerThread);
+            printf("[BLOCK ] MIterPerWarp = %d, NIterPerWarp = %d, KIterPerWarp = %d\n", MIterPerWarp, NIterPerWarp, KIterPerWarp);
+        }
+        // debug A lds read
+        int warp_tile_size_per_thread = a_warp_tensor.get_thread_buffer_size();
+        if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
+        {
+            printf("[BLOCK ] warp_tile_size_per_thread = %d\n", warp_tile_size_per_thread);
+        }
+        for(auto i = 0; i < warp_tile_size_per_thread; i++)
+        {
+            dbg_f16[gid * DEBUG_CNT + i] = a_warp_tensor.get_thread_buffer()[i];
+        }
+        return ;
+#endif

#if 1
        // feifei TODO: Implement gemm here
...
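As a quick check of the per-warp iteration counts that the re-enabled NIterPerWarp/NPerBlockPerIter lines compute, a small host-side sketch follows. The numbers are assumptions taken from flatmm_basic.cpp (128x128x64 block tile, 2x2x1 warps) plus a 32x32x16 warp GEMM standing in for WG::kM/kN/kK.

#include <cstdio>

int main()
{
    // Assumed values mirroring flatmm_basic.cpp and a 32x32x16 warp GEMM.
    constexpr int MPerBlock = 128, NPerBlock = 128, KPerBlock = 64;
    constexpr int MWarp = 2, NWarp = 2;
    constexpr int WG_kM = 32, WG_kN = 32, WG_kK = 16;

    constexpr int MIterPerWarp = MPerBlock / (MWarp * WG_kM); // 128 / 64 = 2
    constexpr int NIterPerWarp = NPerBlock / (NWarp * WG_kN); // 128 / 64 = 2
    constexpr int KIterPerWarp = KPerBlock / WG_kK;           // 64 / 16  = 4

    constexpr int MPerBlockPerIter = MPerBlock / MIterPerWarp; // 64 rows advanced per M iteration
    constexpr int NPerBlockPerIter = NPerBlock / NIterPerWarp; // 64 cols advanced per N iteration

    std::printf("MIter=%d NIter=%d KIter=%d, step M=%d N=%d\n",
                MIterPerWarp, NIterPerWarp, KIterPerWarp, MPerBlockPerIter, NPerBlockPerIter);
    return 0;
}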
include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp  View file @ e889d086

@@ -141,7 +141,7 @@ struct FlatmmKernel
    struct SplitKBatchOffset
    {
-        __device__ SplitKBatchOffset(const FlatmmKernelArgs& kargs,
+        CK_TILE_DEVICE SplitKBatchOffset(const FlatmmKernelArgs& kargs,
                                     const std::size_t k_id = blockIdx.z)
        {
            constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
...
@@ -175,7 +175,42 @@ struct FlatmmKernel
                splitted_k = kargs.K - KRead * (kargs.KBatch - 1);
            }
        }

+#if FEIFEI_DEBUG
+        CK_TILE_HOST SplitKBatchOffset(const FlatmmHostArgs& hargs, const std::size_t k_id = 0)
+        {
+            constexpr auto K1   = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
+            const index_t K_t   = hargs.k_batch * K1;
+            const index_t KRead = (hargs.K + K_t - 1) / K_t * K1;
+
+            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
+            {
+                a_k_split_offset = k_id * KRead;
+            }
+            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
+            {
+                a_k_split_offset = k_id * KRead * hargs.stride_A;
+            }
+
+            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
+            {
+                b_k_split_offset = k_id * KRead * hargs.stride_B;
+            }
+            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
+            {
+                b_k_split_offset = k_id * KRead;
+            }
+
+            if(k_id < static_cast<uint32_t>(hargs.k_batch - 1))
+            {
+                splitted_k = KRead;
+            }
+            else
+            {
+                splitted_k = hargs.K - KRead * (hargs.k_batch - 1);
+            }
+        }
+#endif

        index_t a_k_split_offset;
        index_t b_k_split_offset;
        index_t splitted_k; // problem K after splitted
...
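For reference, the host-side constructor added above mirrors the device-side split-K bookkeeping: every batch reads KRead elements of K and the last batch takes the remainder. A small standalone sketch with made-up numbers (not from this diff) illustrates the arithmetic.

#include <cstdio>

// Hypothetical numbers to illustrate the split-K arithmetic above.
int main()
{
    const int K = 2000, k_batch = 4, K1 = 16;   // K1 = warp-tile K
    const int K_t   = k_batch * K1;             // 64
    const int KRead = (K + K_t - 1) / K_t * K1; // ceil(2000/64) * 16 = 512

    for(int k_id = 0; k_id < k_batch; ++k_id)
    {
        const int splitted_k =
            (k_id < k_batch - 1) ? KRead : K - KRead * (k_batch - 1); // last batch takes the remainder
        std::printf("k_id %d: offset (row-major A) = %d, splitted_k = %d\n",
                    k_id, k_id * KRead, splitted_k); // 512, 512, 512, 464
    }
    return 0;
}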
@@ -362,6 +397,9 @@ struct FlatmmKernel
        return make_tuple(a_tensor_view, b_tensor_view, c_tensor_view);
    }

+#if 1
    template <typename TensorView>
    CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
    {
...
@@ -446,6 +484,118 @@ struct FlatmmKernel
        return make_tuple(a_block_window, b_block_window, c_block_window);
    }
+#else
+    template <typename TensorView>
+    CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
+    {
+        const auto& a_pad_view = [&]() {
+            const auto& a_tensor_view = views.at(I0);
+            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
+            {
+                return pad_tensor_view(a_tensor_view,
+                                       make_tuple(number<TilePartitioner::MPerBlock>{},
+                                                  number<TilePartitioner::KPerBlock>{}),
+                                       sequence<false, FlatmmPipeline::kPadK>{});
+            }
+            else
+            {
+                return pad_tensor_view(a_tensor_view,
+                                       make_tuple(number<TilePartitioner::KPerBlock>{},
+                                                  number<TilePartitioner::MPerBlock>{}),
+                                       sequence<false, FlatmmPipeline::kPadM>{});
+            }
+        }();
+
+        const auto& b_pad_view = [&]() {
+            const auto& b_tensor_view = views.at(I1);
+            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
+            {
+                return pad_tensor_view(b_tensor_view,
+                                       make_tuple(number<TilePartitioner::NPerBlock>{},
+                                                  number<TilePartitioner::KPerBlock>{}),
+                                       sequence<false, FlatmmPipeline::kPadK>{});
+            }
+            else
+            {
+                return pad_tensor_view(b_tensor_view,
+                                       make_tuple(number<TilePartitioner::KPerBlock>{},
+                                                  number<TilePartitioner::NPerBlock>{}),
+                                       sequence<false, FlatmmPipeline::kPadN>{});
+            }
+        }();
+
+        // TODO vector write in for C in ColMajor
+        const auto& c_pad_view = [&]() {
+            const auto& c_tensor_view = views.at(I2);
+            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
+            {
+                return pad_tensor_view(c_tensor_view,
+                                       make_tuple(number<TilePartitioner::MPerBlock>{},
+                                                  number<TilePartitioner::NPerBlock>{}),
+                                       sequence<false, FlatmmPipeline::kPadN>{});
+            }
+            else
+            {
+                return pad_tensor_view(c_tensor_view,
+                                       make_tuple(number<TilePartitioner::MPerBlock>{},
+                                                  number<TilePartitioner::NPerBlock>{}),
+                                       sequence<FlatmmPipeline::kPadM, false>{});
+            }
+        }();
+
+        return make_tuple(a_pad_view, b_pad_view, c_pad_view);
+    }
+
+    template <typename PadView>
+    CK_TILE_DEVICE static auto
+    MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
+    {
+        const auto& a_pad_view = views.at(I0);
+        const auto& b_pad_view = views.at(I1);
+        const auto& c_pad_view = views.at(I2);
+
+        const auto& a_block_window = [&]() {
+            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
+            {
+                return make_tile_window(a_pad_view,
+                                        make_tuple(number<TilePartitioner::MPerBlock>{},
+                                                   number<TilePartitioner::KPerBlock>{}),
+                                        {i_m, 0});
+            }
+            else
+            {
+                return make_tile_window(a_pad_view,
+                                        make_tuple(number<TilePartitioner::KPerBlock>{},
+                                                   number<TilePartitioner::MPerBlock>{}),
+                                        {0, i_m});
+            }
+        }();
+
+        const auto& b_block_window = [&]() {
+            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
+            {
+                return make_tile_window(b_pad_view,
+                                        make_tuple(number<TilePartitioner::NPerBlock>{},
+                                                   number<TilePartitioner::KPerBlock>{}),
+                                        {i_n, 0});
+            }
+            else
+            {
+                return make_tile_window(b_pad_view,
+                                        make_tuple(number<TilePartitioner::KPerBlock>{},
+                                                   number<TilePartitioner::NPerBlock>{}),
+                                        {0, i_n});
+            }
+        }();
+
+        auto c_block_window = make_tile_window(c_pad_view,
+                                               make_tuple(number<TilePartitioner::MPerBlock>{},
+                                                          number<TilePartitioner::NPerBlock>{}),
+                                               {i_m, i_n});
+
+        return make_tuple(a_block_window, b_block_window, c_block_window);
+    }
+#endif

    /**
     * @brief Runs single GEMM problem cooperatively by whole workgroup.
...
@@ -477,20 +627,29 @@ struct FlatmmKernel
#endif
    )
    {
+#if FEIFEI_DEBUG
+        uint32_t tidx = threadIdx.x;
+        uint32_t tidy = threadIdx.y;
+        uint32_t bidx = blockIdx.x;
+        uint32_t bidy = blockIdx.y;
+        uint32_t bdmx = blockDim.x;
+        uint32_t bdmy = blockDim.y;
+        uint32_t gdmx = gridDim.x;
+        uint32_t gdmy = gridDim.y;
+        uint32_t gid  = ((bdmx * bdmy) * gdmx) * bidy + (bdmx * bdmy) * bidx + bdmx * tidy + tidx;
+        half_t* dbg_f16 = static_cast<half_t*>(kargs.dbg_f168_ptr);
+#endif
        // Create Flatmm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
-            a_ptr, b_shuffle_ptr, c_ptr, kargs, splitk_batch_offset);
+            a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
+        // Debug origin layout
+        // const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
+        //     a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
        const auto& gemm_tile_windows =
            MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
#if FEIFEI_DEBUG
        ////////////////////////////////////////////////////////
        const auto& a_gemm_tensor_views = gemm_tensor_views_tuple.at(I0); // tensor_view
...
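The gid computed in the debug block above flattens a 2-D grid of 2-D blocks into one linear index per thread (gdmy is read but not used by the formula). A host-side restatement with assumed launch dimensions, just to show that the formula assigns each thread a distinct slot:

#include <cstdio>

// Host-side restatement of the gid formula from the debug block (assumed dimensions).
int main()
{
    const unsigned bdmx = 64, bdmy = 4; // assumed blockDim
    const unsigned gdmx = 8,  gdmy = 2; // assumed gridDim
    unsigned max_gid = 0;
    for(unsigned bidy = 0; bidy < gdmy; ++bidy)
        for(unsigned bidx = 0; bidx < gdmx; ++bidx)
            for(unsigned tidy = 0; tidy < bdmy; ++tidy)
                for(unsigned tidx = 0; tidx < bdmx; ++tidx)
                {
                    const unsigned gid = ((bdmx * bdmy) * gdmx) * bidy +
                                         (bdmx * bdmy) * bidx + bdmx * tidy + tidx;
                    if(gid > max_gid)
                        max_gid = gid;
                }
    std::printf("max gid = %u for %u threads\n", max_gid, bdmx * bdmy * gdmx * gdmy); // 4095 / 4096
    return 0;
}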
@@ -533,39 +692,51 @@ struct FlatmmKernel
        const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);

+        const auto& b_flat_tensor_view = [&]() {
+            return make_naive_tensor_view<address_space_enum::global>(
+                b_shuffle_ptr,
+                make_tuple(kargs.N, splitk_batch_offset.splitted_k),
+                make_tuple(kargs.stride_B, 1),
+                number<FlatmmPipeline::GetVectorSizeB()>{},
+                number<1>{});
+        }();
+        const auto& b_flat_pad_view = [&]() {
+            return pad_tensor_view(b_flat_tensor_view,
+                                   make_tuple(number<TilePartitioner::NPerBlock>{},
+                                              number<TilePartitioner::KPerBlock>{}),
+                                   sequence<false, FlatmmPipeline::kPadK>{});
+        }();
+        const auto& b_flat_block_window =
+            make_tile_window(b_flat_pad_view,
+                             make_tuple(number<TilePartitioner::NPerBlock>{},
+                                        number<TilePartitioner::KPerBlock>{}),
+                             {block_idx_n, 0});
+
        // Run GEMM cooperatively by whole workgroup.
        const auto& a_block_window = gemm_tile_windows.at(I0);
        const auto& b_block_window = gemm_tile_windows.at(I1);

        const auto& c_block_tile = FlatmmPipeline{}.template operator()(a_block_window,
-                                                                        b_block_window,
+                                                                        b_flat_block_window,
                                                                        num_loop,
                                                                        smem_ptr
#if FEIFEI_DEBUG
                                                                        ,
-                                                                        b_ptr,
+                                                                        b_block_window,
                                                                        dbg_int,
                                                                        dbg_fp32,
                                                                        dbg_f168
#endif
        );

-        // Run Epilogue Pipeline
-        auto& c_block_window = gemm_tile_windows.at(I2);
-        constexpr bool is_output_c_reg_transposed =
-            EpiloguePipeline::IsOutputTransposed() != FlatmmPipeline::IsTransposeC();
-        if constexpr((DstInMemOp == memory_operation_enum::set) || (sizeof(CDataType) > 2) ||
-                     (FlatmmPipeline::VectorSizeC % 2 == 0 &&
-                      std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
-                      is_output_c_reg_transposed))
-        {
-            EpiloguePipeline{}
-                .template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
-                    c_block_window, c_block_tile);
-        }
+        // feifei TODO: Un-comment bellow once pipeline() is implemented
+#if 0
+        // Run Epilogue Pipeline
+        /*auto c_block_window = gemm_tile_windows.at(I2);
+        if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
+        {
+            printf("[PIPELN] C = %.3f\n", type_convert<float>(c_block_tile.get_thread_buffer()[0]));
+        }
+#endif
+        EpiloguePipeline{}
+            .template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
+                c_block_window, c_block_tile, smem_ptr);*/
    }

    CK_TILE_DEVICE void operator()(FlatmmKernelArgs kargs) const
...
include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp  View file @ e889d086

@@ -14,6 +14,10 @@ namespace ck_tile {
template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy> // feifei TODO: add default policy
struct FlatmmPipelineAGmemBGmemCRegV1
{
+    static constexpr auto I0 = number<0>{};
+    static constexpr auto I1 = number<1>{};
+    static constexpr auto I2 = number<2>{};
+
    using ADataType = remove_cvref_t<typename Problem::ADataType>;
    using BDataType = remove_cvref_t<typename Problem::BDataType>;
    using CDataType = remove_cvref_t<typename Problem::CDataType>;
...
@@ -62,18 +66,22 @@ struct FlatmmPipelineAGmemBGmemCRegV1
    }

    template <typename ADramBlockWindowTmp,
-              typename BDramBlockWindowTmp,
+              typename BFlatBlockWindowTmp,
              typename AElementFunction,
-              typename BElementFunction>
+              typename BElementFunction
+#if FEIFEI_DEBUG
+              ,
+              typename BDramBlockWindowTmp
+#endif
+              >
    CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                        const AElementFunction& a_element_func,
-                                        const BDramBlockWindowTmp& b_dram_block_window_tmp,
+                                        const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
                                        const BElementFunction& b_element_func,
                                        index_t num_loop,
                                        void* p_smem
#if FEIFEI_DEBUG
                                        ,
-                                        const BDataType* b_ptr,
+                                        const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                        int* dbg_int,
                                        float* dbg_fp32,
                                        void* dbg_f168
...
@@ -111,63 +119,107 @@ struct FlatmmPipelineAGmemBGmemCRegV1
                      "wrong!");
        static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
-                          kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
+                          kNPerBlock == BFlatBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                      "wrong!");

+        if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
+        {
+            printf("[PIPELN] kMPerBlock = %d, winN = %d\n", kMPerBlock, static_cast<int>(ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}]));
+            printf("[PIPELN] kNPerBlock = %d, winN = %d\n", kNPerBlock, static_cast<int>(BFlatBlockWindowTmp{}.get_window_lengths()[number<0>{}]));
+            printf("[PIPELN] kNPerBlock = %d, winN = %d\n", kNPerBlock, static_cast<int>(BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}]));
+            printf("[PIPELN] kKPerBlock = %d, winN = %d\n", kKPerBlock, static_cast<int>(ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}]));
+        }
+
#if 1
        // feifei TODO: Implement gemm here

        // Get block flatmm
        auto block_flatmm = BlockFlatmm(); // struct BlockFlatmmASmemBSmemCRegV1

-        // A tile in LDS
-        ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
-        constexpr auto a_lds_block_desc = PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
-        auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
-        constexpr index_t a_lds_block_space_size_aligned =
-            integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * 16;
-
        // A DRAM tile window for load
        auto a_copy_dram_window = // tile_window_with_static_distribution
            make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
                             make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
                             a_dram_block_window_tmp.get_window_origin(),
                             PipelinePolicy::template MakeADramTileDistribution<Problem>());

-        // A LDS tile window for store
-        auto a_copy_lds_window = make_tile_window(
-            a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
-
-        // A LDS tile for block GEMM
-        auto a_lds_gemm_window = make_tile_window(
-            a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
+        // B DRAM tile window for load
+        auto b_copy_dram_window = // tile_window_with_static_distribution
+            make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
+                             make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
+                             b_dram_block_window_tmp.get_window_origin(),
+                             PipelinePolicy::template MakeBDramTileDistribution<Problem>());
+
+        // B flat DRAM window for load
+        auto b_flat_distribution =
+            PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
+        auto b_flat_dram_window = // tile_window_with_static_distribution
+            make_tile_window(b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
+                             make_tuple(number<kNPerBlock>{}, number<BlockSize>{} * 4),
+                             b_flat_dram_block_window_tmp.get_window_origin(),
+                             b_flat_distribution);

        // Prefetch -----------------------------------------------------------
        // global read 0
        auto a_block_tile = load_tile(a_copy_dram_window);
+        auto b_block_tile = load_tile(b_copy_dram_window);
+        auto b_flat_tile = load_tile(b_flat_dram_window);

-#if FEIFEI_DEBUG // debug A global load
-        int a_dim = a_block_tile.get_num_of_dimension();
-        int a_sz = a_block_tile.get_thread_buffer_size();
+#if FEIFEI_DEBUG
+        // debug A global load
+        int a_block_tile_size_per_thread = a_block_tile.get_thread_buffer_size();
        if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
        {
-            printf("[PIPELN] a_dim = %d, a_sz = %d\n", a_dim, a_sz);
+            printf("[PIPELN] a_block_tile_size_per_thread = %d\n", a_block_tile_size_per_thread);
        }
-        for(auto i = 0; i < a_sz; i++)
+        for(auto i = 0; i < a_block_tile_size_per_thread; i++)
        {
            dbg_f16[gid * DEBUG_CNT + i] = a_block_tile.get_thread_buffer()[i];
        }
+        // debug B global load
+        int b_block_tile_size_per_thread = b_block_tile.get_thread_buffer_size();
+        if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
+        {
+            printf("[PIPELN] b_block_tile_size_per_thread = %d\n", b_block_tile_size_per_thread);
+        }
+        for(auto i = 0; i < b_block_tile_size_per_thread; i++)
+        {
+            //dbg_f16[gid * DEBUG_CNT + i] = b_block_tile.get_thread_buffer()[i];
+        }
+        // debug flat B global load
+        int b_flat_tile_size_per_thread = b_flat_tile.get_thread_buffer_size();
+        if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
+        {
+            printf("[PIPELN] b_flat_tile_size_per_thread = %d\n", b_flat_tile_size_per_thread);
+        }
+        for(auto i = 0; i < b_flat_tile_size_per_thread; i++)
+        {
+            //dbg_f16[gid * DEBUG_CNT + i + b_block_tile_size_per_thread + 4] = b_flat_tile.get_thread_buffer()[i];
+        }
        return nullptr;
#endif

+#if 0
        // move to 1
        move_tile_window(a_copy_dram_window, {0, kKPerBlock});
+        move_tile_window(b_copy_dram_window, {0, kKPerBlock});
+        // A tile in LDS
+        ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
+        constexpr auto a_lds_block_desc = PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
+        auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
+        // A LDS tile window for store
+        auto a_copy_lds_window = make_tile_window( // tile_window_with_static_lengths
+            a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
+        // A LDS tile for block GEMM
+        auto a_lds_gemm_window = make_tile_window( // tile_window_with_static_lengths
+            a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
        // LDS write 0
        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
...
@@ -183,12 +235,26 @@ struct FlatmmPipelineAGmemBGmemCRegV1
            store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
        }

+        // B tile in LDS
+        constexpr index_t a_lds_block_space_size_aligned = integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * 16;
+        BDataType* p_b_lds = static_cast<BDataType*>(static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
+        constexpr auto b_lds_block_desc = PipelinePolicy::template MakeBLdsBlockDescriptor<Problem>();
+        auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
+        // B LDS tile window for store
+        auto b_copy_lds_window = make_tile_window(
+            b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
+        // B LDS tile for block GEMM
+        auto b_lds_gemm_window = make_tile_window(
+            b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
+
        // Loop ---------------------------------------------------------------
        // Do flatmm
-        block_flatmm(a_lds_gemm_window
+        block_sync_lds();
+        block_flatmm(a_lds_gemm_window, b_lds_gemm_window
#if FEIFEI_DEBUG
                     ,
-                     b_ptr,
                     dbg_int,
                     dbg_fp32,
                     dbg_f168
...
@@ -198,6 +264,157 @@ struct FlatmmPipelineAGmemBGmemCRegV1
        // Tail ---------------------------------------------------------------
        return nullptr;
+#endif
+        ////////////////////////////////////////////////////////////////////////////////////////////////////
// A tile in LDS
/*ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc =
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
constexpr index_t a_lds_block_space_size_aligned =
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
16;
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc =
PipelinePolicy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
// A LDS tile window for store
auto a_copy_lds_window = make_tile_window(
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// B LDS tile window for store
auto b_copy_lds_window = make_tile_window(
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// Block GEMM
auto block_gemm = BlockFlatmm();
// Acc register tile
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
// prefetch
// global read 0
//auto a_block_tile = load_tile(a_copy_dram_window);
//auto b_block_tile = load_tile(b_copy_dram_window);
{
// move to 1
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
PipelinePolicy::template MakeShuffledARegBlockDescriptor<Problem>());
shuffle_tile(a_shuffle_tmp, a_block_tile);
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
store_tile(a_copy_lds_window, a_block_tile_tmp);
}
else
{
store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
}
// LDS write 0
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
PipelinePolicy::template MakeShuffledBRegBlockDescriptor<Problem>());
shuffle_tile(b_shuffle_tmp, b_block_tile);
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp);
store_tile(b_copy_lds_window, b_block_tile_tmp);
}
else
{
store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile));
}
}
index_t iCounter = num_loop - 1;
while(iCounter > 0)
{
// global read i + 1
a_block_tile = load_tile(a_copy_dram_window);
b_block_tile = load_tile(b_copy_dram_window);
block_sync_lds();
// GEMM i
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
// move to i + 2
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
// LDS write i + 1
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window, a_block_tile_tmp);
// LDS write i + 1
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
PipelinePolicy::template MakeShuffledBRegBlockDescriptor<Problem>());
shuffle_tile(b_shuffle_tmp_loop, b_block_tile);
store_tile(b_copy_lds_window,
tile_elementwise_in(b_element_func, b_shuffle_tmp_loop));
}
else
{
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
store_tile(b_copy_lds_window, b_block_tile_tmp);
}
iCounter--;
}
// tail
{
block_sync_lds();
// GEMM num_loop - 1
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
int c_block_tile_size_per_thread = c_block_tile.get_thread_buffer_size();
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[PIPELN] c_block_tile_size_per_thread = %d\n", c_block_tile_size_per_thread);
}
for(auto i = 0; i < c_block_tile_size_per_thread; i++)
{
//dbg_fp32[gid * DEBUG_CNT + i] = c_block_tile.get_thread_buffer()[i];
dbg_fp32[gid * DEBUG_CNT + i] = 3.12f;
c_block_tile.get_thread_buffer()[i] = 1.23f;
}
return c_block_tile;*/
////////////////////////////////////////////////////////////////////////////////////////////////////
#else
        // A tile in LDS
        ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
...
@@ -352,14 +569,18 @@ struct FlatmmPipelineAGmemBGmemCRegV1
#endif
    }

    template <typename ADramBlockWindowTmp,
-              typename BDramBlockWindowTmp>
+              typename BFlatBlockWindowTmp
+#if FEIFEI_DEBUG
+              ,
+              typename BDramBlockWindowTmp
+#endif
+              >
    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
-                                   const BDramBlockWindowTmp& b_dram_block_window_tmp,
+                                   const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
                                   index_t num_loop,
                                   void* p_smem
#if FEIFEI_DEBUG
                                   ,
-                                   const BDataType* b_ptr,
+                                   const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                   int* dbg_int,
                                   float* dbg_fp32,
                                   void* dbg_f168
...
@@ -369,13 +590,13 @@ struct FlatmmPipelineAGmemBGmemCRegV1
        return operator()(
            a_dram_block_window_tmp,
            [](const ADataType& a) { return a; },
-            b_dram_block_window_tmp,
+            b_flat_dram_block_window_tmp,
            [](const BDataType& b) { return b; },
            num_loop,
            p_smem
#if FEIFEI_DEBUG
            ,
-            b_ptr,
+            b_dram_block_window_tmp,
            dbg_int,
            dbg_fp32,
            dbg_f168
...
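A rough check of the LDS budget implied by the A/B LDS tiles referenced in this pipeline: the sketch below assumes fp16 data, an unpadded row-major LDS layout, and the 128x128x64 block tile from flatmm_basic.cpp; the 16-byte rounding mirrors the integer_divide_ceil(...)*16 alignment in the diff. None of these numbers come from the commit itself.

#include <cstdio>

int main()
{
    // Assumed: fp16 A/B and the 128x128x64 block tile from the example.
    const int kMPerBlock = 128, kNPerBlock = 128, kKPerBlock = 64;
    const int elem_bytes = 2; // sizeof(fp16)

    auto align16 = [](int bytes) { return (bytes + 15) / 16 * 16; };

    const int a_lds_bytes = align16(kMPerBlock * kKPerBlock * elem_bytes); // 16384
    const int b_lds_bytes = align16(kNPerBlock * kKPerBlock * elem_bytes); // 16384

    // B is placed right after the 16-byte-aligned A region, as in the diff.
    std::printf("A LDS = %d B, B LDS = %d B, total = %d B\n",
                a_lds_bytes, b_lds_bytes, a_lds_bytes + b_lds_bytes); // 32768 B = 32 KiB
    return 0;
}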
include/ck_tile/ops/flatmm/pipeline/flatmm_universal_pipeline_ag_bg_cr_policy.hpp  View file @ e889d086

@@ -227,15 +227,24 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
        }
        else
        {
            constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType); // dwordx4 load A elem cnt
            constexpr index_t K0 = KPerBlock / K1;                              // threads cnt in K dim
            constexpr index_t M2 = get_warp_size() / K0;                        // threads cnt in M dim (per wave)
            if constexpr(get_warp_size() % (M2 * K0) == 0)
            {
                constexpr index_t M1 = BlockSize / get_warp_size();             // wave cnt in M dim (per block)
                static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
                static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
                constexpr index_t M0 = MPerBlock / (M2 * M1);                   // load repeat times in M dim
+#if FEIFEI_DEBUG
+                if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
+                {
+                    printf("[PIPELN] MakeADramTileDistribution():\n");
+                    printf("[PIPELN] MPerBlock = %d, KPerBlock = %d, AperBlock = %d\n", MPerBlock, KPerBlock, MPerBlock * KPerBlock);
+                    printf("[PIPELN] BlockSize = %d, warp_size = %d, VectorLoadSize = %d\n", BlockSize, get_warp_size(), Problem::VectorLoadSize);
+                    printf("[PIPELN] K1 = %d, K0 = %d, M2 = %d, M1 = %d, M0 = %d\n", K1, K0, M2, M1, M0);
+                }
+#endif
                return make_static_tile_distribution(
                    tile_distribution_encoding<sequence<1>,
                                               tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
...
@@ -310,18 +319,25 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
        }
        else
        {
            constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType); // dwordx4 load B elem cnt
            constexpr index_t K0 = KPerBlock / K1;                              // threads cnt in K dim
            constexpr index_t N2 = get_warp_size() / K0;                        // threads cnt in N dim (per wave)
            // coalesce reading for each blocks
            if constexpr(get_warp_size() % (N2 * K0) == 0)
            {
                constexpr index_t N1 = BlockSize / get_warp_size();             // wave cnt in N dim (per block)
                static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
                static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
                constexpr index_t N0 = NPerBlock / (N2 * N1);                   // load repeat times in N dim
+#if FEIFEI_DEBUG
+                if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
+                {
+                    printf("[PIPELN] MakeBDramTileDistribution():\n");
+                    printf("[PIPELN] NPerBlock = %d, KPerBlock = %d, BperBlock = %d\n", NPerBlock, KPerBlock, NPerBlock * KPerBlock);
+                    printf("[PIPELN] BlockSize = %d, warp_size = %d, VectorLoadSize = %d\n", BlockSize, get_warp_size(), Problem::VectorLoadSize);
+                    printf("[PIPELN] K1 = %d, K0 = %d, N2 = %d, N1 = %d, N0 = %d\n", K1, K0, N2, N1, N0);
+                }
+#endif
                return make_static_tile_distribution(
                    tile_distribution_encoding<sequence<1>,
                                               tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
...
@@ -347,6 +363,92 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
        }
    }

+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
+    {
+        using BDataType = remove_cvref_t<typename Problem::BDataType>;
+        using BLayout   = remove_cvref_t<typename Problem::BLayout>;
+
+        constexpr index_t BlockSize = Problem::kBlockSize;
+
+        constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
+        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
+
+        if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
+        {
+            constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
+            constexpr index_t N0 = NPerBlock / N1;
+            constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
+            static_assert(total_pixels % N1 == 0);
+            constexpr index_t K3    = total_pixels / N1;
+            constexpr index_t KPack = GetVectorLoadSize<Problem, BDataType, NPerBlock>();
+            static_assert(KPack % K3 == 0);
+            constexpr index_t K2 = KPack / K3;
+            if constexpr(get_warp_size() % (K2 * N0) == 0)
+            {
+                constexpr index_t K1 = get_warp_size() / (K2 * N0);
+                constexpr index_t K0 = BlockSize / get_warp_size();
+                static_assert(KPerBlock == K0 * K1 * K2 * K3);
+                return make_static_tile_distribution(
+                    tile_distribution_encoding<sequence<1>,
+                                               tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
+                                               tuple<sequence<2>, sequence<2, 1, 2>>,
+                                               tuple<sequence<0>, sequence<1, 0, 2>>,
+                                               sequence<2, 1>,
+                                               sequence<3, 1>>{});
+            }
+            else
+            {
+                constexpr index_t K1   = (K2 * N0) / get_warp_size();
+                constexpr index_t K2_m = K2 / K1;
+                constexpr index_t K0   = BlockSize / get_warp_size() / K1;
+                static_assert(KPerBlock == K0 * K1 * K2_m * K3);
+                return make_static_tile_distribution(
+                    tile_distribution_encoding<sequence<1>,
+                                               tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
+                                               tuple<sequence<2, 2>, sequence<1, 2>>,
+                                               tuple<sequence<0, 1>, sequence<0, 2>>,
+                                               sequence<2, 1>,
+                                               sequence<3, 1>>{});
+            }
+        }
+        else
+        {
+            constexpr index_t KLoad      = Problem::VectorLoadSize / sizeof(BDataType); // dwordx4 load B elem cnt
+            constexpr index_t KThdInBlk  = 64;
+            constexpr index_t KBlkInTile = 1;
+            constexpr index_t KRepeat    = 1;
+            constexpr index_t NLoad      = 1; // dwordx4 load B elem cnt
+            constexpr index_t NThdInBlk  = 1;
+            constexpr index_t NBlkInTile = 4;
+            constexpr index_t NRepeat    = 1;
+#if FEIFEI_DEBUG
+            if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
+            {
+                printf("[PIPELN] MakeBFlatDramTileDistribution():\n");
+                printf("[PIPELN] NPerBlock = %d, KPerBlock = %d, BperBlock = %d\n", NPerBlock, KPerBlock, NPerBlock * KPerBlock);
+                printf("[PIPELN] BlockSize = %d, warp_size = %d, VectorLoadSize = %d\n", BlockSize, get_warp_size(), Problem::VectorLoadSize);
+                printf("[PIPELN] NRepeat = %d, NBlkInTile = %d, NThdInBlk = %d, NLoad = %d\n", NRepeat, NBlkInTile, NThdInBlk, NLoad);
+                printf("[PIPELN] KRepeat = %d, KBlkInTile = %d, KThdInBlk = %d, KLoad = %d\n", KRepeat, KBlkInTile, KThdInBlk, KLoad);
+            }
+#endif
+            return make_static_tile_distribution(
+                tile_distribution_encoding<sequence<1>,
+                                           tuple<sequence<NRepeat, NBlkInTile, NThdInBlk, NLoad>,
+                                                 sequence<KRepeat, KBlkInTile, KThdInBlk, KLoad>>, // first dim
+                                           tuple<sequence<1, 2>, sequence<1, 2>>,
+                                           tuple<sequence<1, 1>, sequence<2, 2>>,
+                                           sequence<2, 2>,
+                                           sequence<0, 3>>{});
+        }
+    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor()
    {
...
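To get a feel for the numbers the debug printfs in MakeADramTileDistribution report, here is a host-side sketch of the same decomposition. The inputs are assumptions (16-byte dwordx4 loads, fp16 A, wave64, a 256-thread block, and the 128x64 A block tile from the example), not values printed by this commit.

#include <cstdio>

int main()
{
    // Assumed values; mirrors the decomposition in MakeADramTileDistribution.
    const int VectorLoadSize = 16, elem_bytes = 2; // dwordx4 loads of fp16
    const int MPerBlock = 128, KPerBlock = 64;
    const int BlockSize = 256, warp_size = 64;

    const int K1 = VectorLoadSize / elem_bytes; // 8 elements per load
    const int K0 = KPerBlock / K1;              // 8 threads along K
    const int M2 = warp_size / K0;              // 8 threads along M per wave
    const int M1 = BlockSize / warp_size;       // 4 waves per block
    const int M0 = MPerBlock / (M2 * M1);       // 4 load repeats along M

    // Sanity: (M0 * M1 * M2) x (K0 * K1) = 128 x 64 covers the whole A block tile.
    std::printf("K1=%d K0=%d M2=%d M1=%d M0=%d\n", K1, K0, M2, M1, M0);
    return 0;
}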