gaoqiong / composable_kernel_ROCM / Commits

Commit a16bb52d, authored Oct 30, 2024 by carlushuang

Merge remote-tracking branch 'origin/develop' into ck_tile/layernorm_fusion

Parents: 3c305fee, 24d996aa

Showing 20 changed files with 1156 additions and 483 deletions (+1156 -483).
Changed files (additions / deletions):

example/ck_tile/03_gemm/CMakeLists.txt  (+2 -2)
example/ck_tile/03_gemm/gemm_basic.cpp  (+45 -321)
example/ck_tile/03_gemm/gemm_basic.hpp  (+26 -6)
example/ck_tile/03_gemm/gemm_mem_pipeline.cpp  (+188 -0, new)
example/ck_tile/03_gemm/run_gemm_example.inc  (+217 -0, new)
include/ck_tile/core.hpp  (+1 -0)
include/ck_tile/core/tensor/load_tile.hpp  (+26 -1)
include/ck_tile/core/tensor/tile_window.hpp  (+11 -6)
include/ck_tile/core/utility/literals.hpp  (+22 -0, new)
include/ck_tile/host/reference/reference_gemm.hpp  (+21 -39)
include/ck_tile/ops/gemm.hpp  (+2 -0)
include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp  (+1 -1)
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp  (+15 -15)
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp  (+62 -61)
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp  (+14 -10)
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp  (+413 -0)
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp  (+71 -0)
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp  (+12 -12)
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp  (+4 -6)
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp  (+3 -3)
example/ck_tile/03_gemm/CMakeLists.txt (+2 -2)

The resulting file (the duplicate gemm_basic line in the capture is side-by-side context):

set(CMAKE_BUILD_TYPE Debug)
add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
add_executable(tile_example_gemm_mem_pipeline EXCLUDE_FROM_ALL gemm_mem_pipeline.cpp)
\ No newline at end of file
example/ck_tile/03_gemm/gemm_basic.cpp (+45 -321)

The old self-contained driver is gutted: argument parsing moves to gemm_basic.hpp and the host-side driver moves to the new shared run_gemm_example.inc, leaving only the kernel-launch path here. The new version, reconstructed from the side-by-side diff:

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include <hip/hip_runtime.h>

#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include <tuple>

#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"

template <typename ALayout, typename BLayout, typename CLayout>
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
{
    // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
    constexpr bool kPadA = true;
    constexpr bool kPadB = true;
    constexpr bool kPadC = true;

    constexpr int kBlockPerCu = 1;

    // This part comes from the Codegen
    constexpr ck_tile::index_t M_Tile = 128;
    constexpr ck_tile::index_t N_Tile = 128;
    constexpr ck_tile::index_t K_Tile = 32;

    constexpr ck_tile::index_t M_Warp = 2;
    constexpr ck_tile::index_t N_Warp = 2;
    constexpr ck_tile::index_t K_Warp = 1;

    constexpr ck_tile::index_t M_Warp_Tile = 32;
    constexpr ck_tile::index_t N_Warp_Tile = 32;
    constexpr ck_tile::index_t K_Warp_Tile = 8;

    // Whether doing the CShuffle (transpose before the global memory), depending on the output
    // layout.
    constexpr bool CShuffleEpilogue =
        std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;

    using CodegenGemmShape =
        ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                               ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                               ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;

    using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;

    using GemmEpilogue = std::conditional_t<
        CShuffleEpilogue,
        /* CShuffle epilogue type, elided in the diff view: ... TilePartitioner::kN>>, */
        ck_tile::Default2DEpilogue<
            ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>>;

    using CodegenGemmTraits =
        ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>;
    using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<ADataType,
                                                                BDataType,
                                                                AccDataType,
                                                                CodegenGemmShape,
                                                                CodegenGemmTraits>;
    using CodegenGemmPolicy =
        ck_tile::UniversalGemmPipelineAgBgCrPolicy<ALayout, BLayout, CLayout>;
    using CodegenGemmPipeline =
        ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
    // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
    // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
    using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;

    auto kargs = Kernel::MakeKargs(args.p_a,
                                   args.p_b,
                                   args.p_c,
                                   args.M,
                                   args.N,
                                   args.K,
                                   ...); // remaining (stride) arguments elided in the diff view

    const dim3 grids      = Kernel::GridSize(args.M, args.N, args.kbatch);
    constexpr dim3 blocks = Kernel::BlockSize();

    if(s.log_level_ > 0)
    {
        std::cout << "Launching kernel with args:"
                  << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                  << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                  << std::endl;
    }

    float ave_time = ck_tile::launch_kernel(
        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));

    return ave_time;
}

#include "run_gemm_example.inc"

int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }

Removed (the 321 deleted lines): the old create_args() with defaults m=1024, n=2048, k=64, an "e" absolute-error-tolerance option and warmup=10, superseded by the version now in gemm_basic.hpp; the old gemm_calc(), which was additionally templated on PipelineProblem, GemmPipeline and GemmShape, carried kTilePermute/kOutputRank codegen placeholders, and passed args.epsilon into MakeKargs(); invoke_gemm(), which checked the "prec" option against DataTypeTraits<DataType>::name, derived default strides from the layouts, and printed the overall performance (running time and GB/s); and the old main(), which hard-coded RowMajor A, ColumnMajor B, RowMajor C, filled host tensors with FillUniformDistribution{-5.f, 5.f}, verified against reference_gemm (v=1, CPU) or reference_gemm_gpu (v=2, GPU), and returned !pass_gpu. That driver logic now lives in run_gemm_example.inc.
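The driver fixes a 128x128x32 block tile. As a sanity check on what GridSize() has to produce, here is a minimal sketch of the ceil-division tile count, assuming (as the tile partitioner's name suggests, though its code is not shown in this diff) one workgroup per M_Tile x N_Tile output tile:

#include <cstdio>

// Hypothetical helper mirroring the usual tile-count arithmetic; the real
// ck_tile::GemmTilePartitioner may order or swizzle blocks differently.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int M = 3840, N = 4096;         // defaults from create_args()
    const int M_Tile = 128, N_Tile = 128; // block tile from the example
    // One workgroup per output tile: 30 x 32 = 960 blocks (times kbatch).
    std::printf("grid = %d x %d = %d blocks\n",
                ceil_div(M, M_Tile), ceil_div(N, N_Tile),
                ceil_div(M, M_Tile) * ceil_div(N, N_Tile));
}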
example/ck_tile/03_gemm/gemm_basic.hpp (+26 -6)

@@ -4,12 +4,10 @@ The epilogue/gemm/host includes move out of this header (they now live in the .cpp files) and <string> moves below the ck_tile includes:

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include <string>

template <typename DataType>
struct GemmBasicTypeConfig;

@@ -20,7 +18,7 @@ In GemmBasicTypeConfig<ck_tile::half_t>, the stray "// type convert" comment on CDataType is dropped:

using ADataType   = ck_tile::half_t;
using BDataType   = ck_tile::half_t;
using AccDataType = float;
using CDataType   = ck_tile::half_t;
// ToDo: Add more bias config to support different categories of GEMM.

@@ -58,7 +56,6 @@ gemm_basic_args loses its float epsilon member:

const void* p_a;
const void* p_b;
void* p_c;
ck_tile::index_t kbatch;
ck_tile::index_t M;
ck_tile::index_t N;
...

@@ -68,5 +65,28 @@ create_args() moves here from gemm_basic.cpp, with larger default sizes, layout options, and warmup raised to 50:

ck_tile::index_t stride_C;
};

auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("b", "1", "batch size")
        .insert("m", "3840", "m dimension")
        .insert("n", "4096", "n dimension")
        .insert("k", "2048", "k dimension")
        .insert("a_layout", "R", "A tensor data layout - Row by default")
        .insert("b_layout", "R", "B tensor data layout - Row by default")
        .insert("c_layout", "R", "C tensor data layout - Row by default")
        .insert("stride_a", "0", "Tensor A stride")
        .insert("stride_b", "0", "Tensor B stride")
        .insert("stride_c", "0", "Tensor C stride")
        .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
        .insert("warmup", "50", "number of iterations before benchmarking the kernel")
        .insert("repeat", "100", "number of iterations to benchmark the kernel")
        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}

// host API
float gemm_calc(gemm_basic_args args, const ck_tile::stream_config& s);
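create_args() returns the parser together with a success flag, so callers unpack it with structured bindings before reading options. A minimal sketch of the consumption pattern, assuming gemm_basic.hpp and the ck_tile host headers are on the include path as in the example's CMake target (the getters shown, get_int and get_str, are the ones this commit itself uses):

#include "gemm_basic.hpp"

int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1; // unknown flag or malformed value

    // Typed getters match the string defaults registered in create_args().
    ck_tile::index_t M = arg_parser.get_int("m");    // 3840 unless overridden
    std::string prec   = arg_parser.get_str("prec"); // "fp16" by default
    return (M > 0 && !prec.empty()) ? 0 : 1;
}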
example/ck_tile/03_gemm/gemm_mem_pipeline.cpp (new file, +188 -0)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include <hip/hip_runtime.h>

#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include <tuple>

#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"

template <typename ALayout, typename BLayout, typename CLayout>
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
{
    // ToDo: This will be modified by the codegen code later.
    constexpr ck_tile::index_t M_Tile = 128;
    constexpr ck_tile::index_t N_Tile = 128;
    constexpr ck_tile::index_t K_Tile = 32;

    constexpr ck_tile::index_t M_Warp = 2;
    constexpr ck_tile::index_t N_Warp = 2;
    constexpr ck_tile::index_t K_Warp = 1;

    constexpr ck_tile::index_t M_Warp_Tile = 32;
    constexpr ck_tile::index_t N_Warp_Tile = 32;
    constexpr ck_tile::index_t K_Warp_Tile = 8;

    // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
    constexpr bool kPadA = true;
    constexpr bool kPadB = true;
    constexpr bool kPadC = true;

    constexpr int kBlockPerCu = 1;

    // ===============================================

    using GemmShape =
        ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                               ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                               ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
    using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;

    using GemmEpilogue = ck_tile::Default2DEpilogue<
        ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>;

    using Traits = ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>;

    using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
        ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;

    const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
    const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
    const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);

    float ave_time{0};

    const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
        constexpr bool has_hot_loop_v = has_hot_loop_.value;
        constexpr auto tail_number_v  = tail_number_.value;

        using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<
            ck_tile::UniversalGemmPipelineProblem<ADataType,
                                                  BDataType,
                                                  AccDataType,
                                                  GemmShape,
                                                  Traits,
                                                  ck_tile::GemmPipelineScheduler::Intrawave,
                                                  has_hot_loop_v,
                                                  tail_number_v>>;
        using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
        auto kargs   = Kernel::MakeKargs(args.p_a,
                                         args.p_b,
                                         args.p_c,
                                         args.M,
                                         args.N,
                                         args.K,
                                         args.stride_A,
                                         args.stride_B,
                                         args.stride_C);

        const dim3 grids      = Kernel::GridSize(args.M, args.N, args.kbatch);
        constexpr dim3 blocks = Kernel::BlockSize();

        if(s.log_level_ > 0)
        {
            std::cout << "Launching kernel with args:"
                      << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                      << std::endl;
        }

        ave_time = ck_tile::launch_kernel(
            s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        return ave_time;
    };

    if(has_hot_loop)
    {
        // Tail pipeline One to Seven
        if(tail_num == ck_tile::TailNumber::One)
        {
            Run(ck_tile::bool_constant<true>{},
                ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
        }
        else if(tail_num == ck_tile::TailNumber::Full)
        {
            Run(ck_tile::bool_constant<true>{},
                ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
        }

        if constexpr(BaseGemmPipeline::PrefetchStages > 2)
        {
            if(tail_num == ck_tile::TailNumber::Two)
            {
                Run(ck_tile::bool_constant<true>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
            }
        }
        if constexpr(BaseGemmPipeline::PrefetchStages > 3)
        {
            if(tail_num == ck_tile::TailNumber::Three)
            {
                Run(ck_tile::bool_constant<true>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
            }
        }
        if constexpr(BaseGemmPipeline::PrefetchStages > 4)
        {
            if(tail_num == ck_tile::TailNumber::Four)
            {
                Run(ck_tile::bool_constant<true>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
            }
        }
        if constexpr(BaseGemmPipeline::PrefetchStages > 5)
        {
            if(tail_num == ck_tile::TailNumber::Five)
            {
                Run(ck_tile::bool_constant<true>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
            }
        }
        if constexpr(BaseGemmPipeline::PrefetchStages > 6)
        {
            if(tail_num == ck_tile::TailNumber::Six)
            {
                Run(ck_tile::bool_constant<true>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
            }
        }
        if constexpr(BaseGemmPipeline::PrefetchStages > 7)
        {
            if(tail_num == ck_tile::TailNumber::Seven)
            {
                Run(ck_tile::bool_constant<true>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
            }
        }
    }
    else
    {
        // Tail number always Full - #PrefetchStages
        if(tail_num == ck_tile::TailNumber::Full)
        {
            Run(ck_tile::bool_constant<false>{},
                ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
        }
        else
        {
            std::ostringstream err;
            err << "When there's no hot loop, this tail number \"" << tail_num
                << "\" is not supported! " << __FILE__ << ":" << __LINE__
                << ", in function: " << __func__;
            throw std::runtime_error(err.str());
        }
    }

    return ave_time;
}

#include "run_gemm_example.inc"

int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
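The Run lambda turns the runtime has_hot_loop/tail_num values into template parameters by passing bool_constant/integral_constant wrappers, so each (hot-loop, tail) combination instantiates its own fully specialized pipeline. A minimal, self-contained sketch of that dispatch pattern, using standard-library stand-ins for ck_tile::bool_constant and ck_tile::integral_constant (the enum and the chosen values here are made up for illustration):

#include <cstdio>
#include <type_traits>

enum class TailNumber { One, Two, Full };

// The callee sees compile-time constants and can specialize freely.
template <bool HasHotLoop, TailNumber Tail>
void run_pipeline()
{
    std::printf("instantiated: hot_loop=%d tail=%d\n", HasHotLoop, static_cast<int>(Tail));
}

int main()
{
    const bool has_hot_loop = true;            // runtime values, e.g. derived from K
    const TailNumber tail   = TailNumber::Two;

    auto run = [](auto hot_, auto tail_) {
        // ::value lifts the runtime branch into template arguments.
        run_pipeline<decltype(hot_)::value, decltype(tail_)::value>();
    };

    // One if/else branch per supported combination, exactly as in the example.
    if(has_hot_loop && tail == TailNumber::Two)
        run(std::bool_constant<true>{},
            std::integral_constant<TailNumber, TailNumber::Two>{});
    else if(!has_hot_loop && tail == TailNumber::Full)
        run(std::bool_constant<false>{},
            std::integral_constant<TailNumber, TailNumber::Full>{});
}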
example/ck_tile/03_gemm/run_gemm_example.inc (new file, +217 -0)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

template <typename ALayout, typename BLayout, typename CLayout>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                  ck_tile::DeviceMem& b_k_n_dev_buf,
                  ck_tile::DeviceMem& c_m_n_dev_buf,
                  ck_tile::index_t M,
                  ck_tile::index_t N,
                  ck_tile::index_t K,
                  ck_tile::index_t stride_A,
                  ck_tile::index_t stride_B,
                  ck_tile::index_t stride_C,
                  ck_tile::index_t kbatch,
                  int n_warmup,
                  int n_repeat)
{
    gemm_basic_args args;
    args.p_a      = a_m_k_dev_buf.GetDeviceBuffer();
    args.p_b      = b_k_n_dev_buf.GetDeviceBuffer();
    args.p_c      = c_m_n_dev_buf.GetDeviceBuffer();
    args.kbatch   = kbatch;
    args.M        = M;
    args.N        = N;
    args.K        = K;
    args.stride_A = stride_A;
    args.stride_B = stride_B;
    args.stride_C = stride_C;

    float ave_time = gemm_calc<ALayout, BLayout, CLayout>(
        args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});

    std::string op_name{"Gemm{MemBoundPipeline}"};

    std::size_t flop = std::size_t(2) * M * N * K;
    std::size_t num_byte =
        sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
    float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_byte / 1.E6 / ave_time;

    std::cout << "Run " << op_name << " kernel with M =" << M << " N =" << N << " K =" << K
              << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
              << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << std::endl;

    return ave_time;
}

template <typename ALayout, typename BLayout, typename CLayout>
int run_gemm_example_with_layouts(int argc,
                                  char* argv[],
                                  const ALayout a_layout = ALayout{},
                                  const BLayout b_layout = BLayout{},
                                  [[maybe_unused]] const CLayout c_layout = CLayout{})
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    ck_tile::index_t M = arg_parser.get_int("m");
    ck_tile::index_t N = arg_parser.get_int("n");
    ck_tile::index_t K = arg_parser.get_int("k");

    ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
    ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
    ck_tile::index_t stride_C = arg_parser.get_int("stride_c");

    ck_tile::index_t batch_size = arg_parser.get_int("b");
    int n_warmup = arg_parser.get_int("warmup");
    int n_repeat = arg_parser.get_int("repeat");

    using namespace ck_tile::literals;

    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
            {
                return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
            }
            else
            {
                return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
            }
        };

    auto f_get_default_stride = [](std::size_t row,
                                   std::size_t col,
                                   std::size_t stride,
                                   auto layout) {
        if(stride == 0)
        {
            // give a chance if stride is zero, return a default packed stride
            if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
            {
                return col;
            }
            else
            {
                return row;
            }
        }
        else
            return stride;
    };

    stride_A = f_get_default_stride(M, K, stride_A, a_layout);
    stride_B = f_get_default_stride(K, N, stride_B, b_layout);
    stride_C = f_get_default_stride(M, N, stride_C, CLayout{});

    ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, a_layout));
    ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, b_layout));
    ck_tile::HostTensor<CDataType> c_m_n_dev_result(
        f_host_tensor_descriptor(M, N, stride_C, CLayout{}));

    // TODO: add different init types
    ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
    ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);

    ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
    ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
    ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());

    a_m_k_dev_buf.ToDevice(a_m_k.data());
    b_k_n_dev_buf.ToDevice(b_k_n.data());
    c_m_n_dev_buf.SetZero();
    c_m_n_dev_result.SetZero();

    invoke_gemm<ALayout, BLayout, CLayout>(a_m_k_dev_buf,
                                           b_k_n_dev_buf,
                                           c_m_n_dev_buf,
                                           M,
                                           N,
                                           K,
                                           stride_A,
                                           stride_B,
                                           stride_C,
                                           batch_size,
                                           n_warmup,
                                           n_repeat);

    c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());

    bool pass = true;

    if(arg_parser.get_int("v") == 1)
    {
        ck_tile::HostTensor<CDataType> c_m_n_host_ref(
            f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
        c_m_n_host_ref.SetZero();

        ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
            a_m_k, b_k_n, c_m_n_host_ref);
        pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);

        std::cout << "The CPU verification result is: " << (pass ? "correct" : "fail")
                  << std::endl;
    }
    else if(arg_parser.get_int("v") == 2)
    {
        ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
            f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
        ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
        c_m_n_gpu_ref.SetZero();
        c_m_n_gpu_buf_ref.SetZero();

        ck_tile::reference_gemm_gpu<ADataType,
                                    BDataType,
                                    AccDataType,
                                    CDataType,
                                    ALayout,
                                    BLayout,
                                    CLayout>(
            a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_gpu_buf_ref, M, N, K, stride_A, stride_B, stride_C);

        c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
        pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref);

        std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail")
                  << std::endl;
    }

    return pass;
}

int run_gemm_example(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    using Row = ck_tile::tensor_layout::gemm::RowMajor;
    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

    std::string a_layout = arg_parser.get_str("a_layout");
    std::string b_layout = arg_parser.get_str("b_layout");

    if(a_layout == "R" && b_layout == "R")
    {
        return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
    }
    else if(a_layout == "R" && b_layout == "C")
    {
        return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
    }
    else if(a_layout == "C" && b_layout == "C")
    {
        return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
    }
    else if(a_layout == "C" && b_layout == "R")
    {
        return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
    }
    else
    {
        throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
    }
}
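The reported numbers follow directly from flop = 2*M*N*K and the fp16 byte counts above. A worked check for the default problem size, assuming a hypothetical 1.0 ms kernel time (half_t is 2 bytes):

#include <cstddef>
#include <cstdio>

int main()
{
    const std::size_t M = 3840, N = 4096, K = 2048;
    const double ave_time_ms = 1.0; // hypothetical measurement

    const std::size_t flop  = std::size_t(2) * M * N * K;        // ~6.44e10
    const std::size_t bytes = 2 * M * K + 2 * N * K + 2 * M * N; // fp16 A, B, C

    // Same scaling as the example: ms in the denominator, so /1.E9 gives TFLOPS
    // and /1.E6 gives GB/s. Here: ~64.4 TFLOPS and ~64.0 GB/s at 1 ms.
    std::printf("%.2f TFLOPS, %.2f GB/s\n",
                flop / 1.E9 / ave_time_ms, bytes / 1.E6 / ave_time_ms);
}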
include/ck_tile/core.hpp (+1 -0)

@@ -57,6 +57,7 @@
 #include "ck_tile/core/utility/functional.hpp"
 #include "ck_tile/core/utility/functional_with_tuple.hpp"
 #include "ck_tile/core/utility/ignore.hpp"
+#include "ck_tile/core/utility/literals.hpp"
 #include "ck_tile/core/utility/magic_div.hpp"
 #include "ck_tile/core/utility/philox_rand.hpp"
 #include "ck_tile/core/utility/random.hpp"
include/ck_tile/core/tensor/load_tile.hpp (+26 -1)

The copyright year is bumped (2018-2023 to 2018-2024), and a new load_tile overload is added that loads into a caller-provided distributed tensor instead of returning a fresh one:

@@ -46,6 +46,31 @@
    return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
}

template <typename DistributedTensor_,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
                              const tile_window_with_static_distribution<BottomTensorView_,
                                                                         WindowLengths_,
                                                                         TileDistribution_,
                                                                         NumCoord>& tile_window,
                              bool_constant<oob_conditional_check> = {})
{
    return tile_window.load(dst_tile, bool_constant<oob_conditional_check>{});
}

A doc comment is also added ahead of the raw-load variant (typo fixed, "Bare" to "Bear"):

/**
 * @brief Loads a tile of data using inline assembly.
 *
 * @note Bear in mind that when loading data this way, you have to manually initialize your
 * thread buffer and synchronize the load afterwards in order to make sure it's done before
 * using the loaded data from registers.
 * @see `tile_window_with_static_distribution::init_raw()` and `buffer_view.hpp`
 * @see `buffer_load_fence()`
 */
template <typename T,
          typename BottomTensorView_,
          typename WindowLengths_,
          ...
include/ck_tile/core/tensor/tile_window.hpp (+11 -6)

The copyright year is bumped to 2018-2024. The load() member is split: the no-argument form now just allocates the destination tensor and forwards to a new overload that writes into a caller-provided tensor:

@@ -290,15 +290,22 @@
CK_TILE_DEVICE auto load(number<i_access_unsupport_> = {},
                         bool_constant<oob_conditional_check> = {}) const
{
    constexpr auto tile_dstr = TileDstr{};
    auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
    load(dst_tensor, bool_constant<oob_conditional_check>{});
    return dst_tensor;
}

template <typename DistributedTensor, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
                         bool_constant<oob_conditional_check> = {}) const
{
    using Traits   = load_store_traits;
    using vector_t = typename Traits::vector_t;
    using SFC_Ys   = typename Traits::SFC_Ys;

    constexpr auto tile_dstr = TileDstr{};

    // loop over thread tensor space [y0, y1, ...]
    static_for<0, NumCoord, 1>{}([&](auto iCoord) {
        /// TODO: use structure binding (to be captured later) if compiled in C++20
        ...

@@ -353,8 +360,6 @@ The trailing return is dropped from the loop-based body; the out-parameter overload writes in place, and only the wrapper above returns the tensor:

            }
        });
    });
}
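This refactor is the classic "return-by-value wrapper over an out-parameter core" split: the out-parameter overload (which the new load_tile free function above exposes) lets callers reuse one register-backed tile across loop iterations, while the old signature keeps working. A minimal CPU-side analogue of the pattern; a plain std::array stands in for the distributed tensor, and nothing here is ck_tile API:

#include <array>
#include <cstddef>
#include <cstdio>

struct Window
{
    // Core: fill a caller-owned buffer (mirrors load(DistributedTensor&)).
    void load(std::array<float, 4>& dst) const
    {
        for(std::size_t i = 0; i < dst.size(); ++i)
            dst[i] = static_cast<float>(i) + base;
    }

    // Convenience wrapper: allocate, delegate, return (mirrors load()).
    std::array<float, 4> load() const
    {
        std::array<float, 4> dst{};
        load(dst);
        return dst;
    }

    float base = 0.f;
};

int main()
{
    Window w{10.f};
    std::array<float, 4> tile{}; // allocated once ...
    for(int iter = 0; iter < 3; ++iter)
    {
        w.base += 1.f;
        w.load(tile);            // ... refilled in place each iteration
    }
    std::printf("%g %g %g %g\n", tile[0], tile[1], tile[2], tile[3]);
}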
include/ck_tile/core/utility/literals.hpp (new file, +22 -0)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

namespace ck_tile {
namespace literals {

// [P0330] Literal Suffix for (signed) size_t (C++23)
// ref: https://wg21.link/p0330r8
inline constexpr std::size_t operator""_uz(unsigned long long size)
{
    return static_cast<std::size_t>(size);
}

inline constexpr std::size_t operator""_zu(unsigned long long size)
{
    return static_cast<std::size_t>(size);
}

} // namespace literals
} // namespace ck_tile
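These user-defined literals backport the C++23 P0330 `uz` suffix so brace-initializer lists of sizes stay homogeneous, which run_gemm_example.inc relies on when mixing a runtime stride with a constant: {stride, 1_uz} gives std::size_t for both elements. A small self-contained illustration (the operator is copied from the new header):

#include <cstddef>
#include <cstdio>
#include <type_traits>

namespace ck_tile { namespace literals {
inline constexpr std::size_t operator""_uz(unsigned long long size)
{
    return static_cast<std::size_t>(size);
}
}} // namespace ck_tile::literals

int main()
{
    using namespace ck_tile::literals;
    auto one = 1_uz;
    static_assert(std::is_same_v<decltype(one), std::size_t>, "suffix yields size_t");

    // Without the suffix, {stride, 1} would mix std::size_t and int.
    std::size_t stride    = 4096;
    std::size_t strides[] = {stride, 1_uz};
    std::printf("%zu %zu\n", strides[0], strides[1]);
}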
include/ck_tile/host/reference/reference_gemm.hpp (+21 -39)

The copyright year is bumped to 2018-2024; <cstdlib> is added and <thread> moves up with the other system includes; "ck_tile/ops/common/tensor_layout.hpp" is dropped, since the new version no longer needs layout tags:

#pragma once

#include <cstdlib>
#include <thread>

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

namespace ck_tile {

@@ -14,55 +15,36 @@ reference_gemm() is simplified: the LayoutA/LayoutB/LayoutC template parameters are removed, B is passed as b_k_n (K x N) instead of b_n_k, M/N/K are read directly with get_length(), and the parallel functor now runs over both M and N:

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename AElementOp   = ck_tile::identity,
          typename BElementOp   = ck_tile::identity,
          typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
                                 const HostTensor<BDataType>& b_k_n,
                                 HostTensor<CDataType>& c_m_n,
                                 const AElementOp& a_element_op     = {},
                                 const BElementOp& b_element_op     = {},
                                 const ACCElementOp& acc_element_op = {})
{
    const std::size_t M = a_m_k.get_length(0);
    const std::size_t N = b_k_n.get_length(1);
    const std::size_t K = a_m_k.get_length(1);

    auto f_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;

        for(std::size_t k = 0; k < K; ++k)
        {
            ADataType v_a = a_element_op(a_m_k(m, k));
            BDataType v_b = b_element_op(b_k_n(k, n));

            v_acc += ck_tile::type_convert<AccDataType>(v_a) *
                     ck_tile::type_convert<AccDataType>(v_b);
        }

        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };

    make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
}

The removed version derived M, N and K from layout-dependent mDesc lengths, indexed a_m_k, b_n_k and c_m_n through RowMajor/ColumnMajor branches, and parallelized over M only via make_ParallelTensorFunctor(f, M). Layout handling now lives in the HostTensor descriptor strides instead (see run_gemm_example.inc).

template <typename ADataType,
...
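Dropping the layout template arguments works because a HostTensor indexed as (m, n) can represent either storage order purely through its strides: row-major uses {stride, 1}, column-major {1, stride}, which is exactly what f_host_tensor_descriptor in run_gemm_example.inc builds. A small sketch of the offset arithmetic assumed there, using plain arrays rather than reimplementing HostTensorDescriptor:

#include <cstddef>
#include <cstdio>

// Offset of element (i, j) for a 2-D tensor with the given strides.
std::size_t offset(std::size_t i, std::size_t j, std::size_t s0, std::size_t s1)
{
    return i * s0 + j * s1;
}

int main()
{
    const std::size_t K = 2048, N = 4096;
    // b(k, n) stored row-major (K x N): strides {N, 1}
    std::printf("row-major    b(1, 2) -> %zu\n", offset(1, 2, N, 1));
    // the same logical b(k, n) stored column-major: strides {1, K}
    std::printf("column-major b(1, 2) -> %zu\n", offset(1, 2, 1, K));
}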
include/ck_tile/ops/gemm.hpp (+2 -0)

@@ -24,6 +24,8 @@
 #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
 #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
 #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp (+1 -1)

@@ -32,7 +32,7 @@ The redundant ck_tile:: qualifier is dropped inside the namespace:

 BlockGemmProblem<ADataType, BDataType, CDataType, kBlockSize, BlockGemmShape>,
 BlockGemmARegBGmemCRegV1DefaultPolicy>;

-CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
+CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
 {
     return sizeof(BDataType) *
            Policy::template MakeBSmemBlockDescriptor<Problem>().get_element_space_size();
...
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp (+15 -15)

A pure rename pass: the ABlockWindowTmp/BBlockWindowTmp template parameters and the *_tmp argument names lose their Tmp/_tmp suffixes. The updated code:

@@ -24,19 +24,19 @@
static constexpr index_t kBlockSize = Problem::kBlockSize;

// C += A * B
template <typename CBlockTensor, typename ABlockWindow, typename BBlockWindow>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
                               const ABlockWindow& a_block_window,
                               const BBlockWindow& b_block_window) const
{
    static_assert(std::is_same_v<ADataType, typename ABlockWindow::DataType> &&
                      std::is_same_v<BDataType, typename BBlockWindow::DataType> &&
                      std::is_same_v<CDataType, typename CBlockTensor::DataType>,
                  "wrong!");

    constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}];
    constexpr index_t NPerBlock = BBlockWindow{}.get_window_lengths()[number<0>{}];
    constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}];

    static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
                      KPerBlock == BlockGemmShape::kK,
                  ...

@@ -62,9 +62,9 @@
    // construct A-warp-window
    auto a_warp_window_tmp = make_tile_window(
        a_block_window.get_bottom_tensor_view(),
        make_tuple(number<WG::kM>{}, number<WG::kK>{}),
        a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
        make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
#if 0 // FIXME: using array will cause register spill
...

@@ -97,9 +97,9 @@
    // construct B-warp-window
    auto b_warp_window_tmp = make_tile_window(
        b_block_window.get_bottom_tensor_view(),
        make_tuple(number<WG::kN>{}, number<WG::kK>{}),
        b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
        make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
#if 0 // FIXME: using array will cause register spill
...

@@ -200,12 +200,12 @@
}

// C = A * B
template <typename ABlockTensorTmp, typename BBlockWindow>
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
                               const BBlockWindow& b_block_window) const
{
    auto c_block_tensor = MakeCBlockTile();
    operator()(c_block_tensor, a_block_tensor_tmp, b_block_window);
    return c_block_tensor;
}
};
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
a16bb52d
...
@@ -3,12 +3,13 @@
...
@@ -3,12 +3,13 @@
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include <iostream>
#include <iostream>
#include <string>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
template
<
typename
TilePartitioner_
,
typename
GemmPipeline_
,
typename
EpiloguePipeline_
>
template
<
typename
TilePartitioner_
,
typename
GemmPipeline_
,
typename
EpiloguePipeline_
>
...
@@ -17,20 +18,19 @@ struct GemmKernel
...
@@ -17,20 +18,19 @@ struct GemmKernel
using
TilePartitioner
=
remove_cvref_t
<
TilePartitioner_
>
;
using
TilePartitioner
=
remove_cvref_t
<
TilePartitioner_
>
;
using
GemmPipeline
=
remove_cvref_t
<
GemmPipeline_
>
;
using
GemmPipeline
=
remove_cvref_t
<
GemmPipeline_
>
;
using
EpiloguePipeline
=
remove_cvref_t
<
EpiloguePipeline_
>
;
using
EpiloguePipeline
=
remove_cvref_t
<
EpiloguePipeline_
>
;
static
constexpr
index_t
KernelBlockSize
=
GemmPipeline
::
kBlockSize
;
using
ALayout
=
remove_cvref_t
<
typename
GemmPipeline
::
ALayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
GemmPipeline
::
BLayout
>
;
using
ADataType
=
remove_cvref_t
<
typename
GemmPipeline
::
ADataType
>
;
using
CLayout
=
remove_cvref_t
<
typename
GemmPipeline
::
CLayout
>
;
using
BDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
BDataType
>
;
static
constexpr
index_t
KernelBlockSize
=
GemmPipeline
::
BlockSize
;
using
CAccDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
CDataType
>
;
using
CODataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
ODataType
>
;
using
LayoutA
=
remove_cvref_t
<
typename
GemmPipeline
::
LayoutA
>
;
using
ADataType
=
remove_cvref_t
<
typename
GemmPipeline
::
ADataType
>
;
using
LayoutB
=
remove_cvref_t
<
typename
GemmPipeline
::
LayoutB
>
;
using
BDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
BDataType
>
;
using
LayoutC
=
remove_cvref_t
<
typename
GemmPipeline
::
LayoutC
>
;
// using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
using
CDataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
ODataType
>
;
__host__
static
constexpr
auto
GridSize
(
index_t
M
_size
,
index_t
N
_size
,
index_t
Batch
_size
)
__host__
static
constexpr
auto
GridSize
(
index_t
M
,
index_t
N
,
index_t
K
Batch
)
{
{
return
TilePartitioner
::
GridSize
(
M
_size
,
N_size
,
Batch
_size
);
return
TilePartitioner
::
GridSize
(
M
,
N
,
K
Batch
);
}
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
KernelBlockSize
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
KernelBlockSize
);
}
...
@@ -40,34 +40,30 @@ struct GemmKernel
...
@@ -40,34 +40,30 @@ struct GemmKernel
const
void
*
a_ptr
;
const
void
*
a_ptr
;
const
void
*
b_ptr
;
const
void
*
b_ptr
;
void
*
c_ptr
;
void
*
c_ptr
;
index_t
M
;
float
epsilon
;
index_t
N
;
index_t
K
;
ck_tile
::
index_t
M
;
index_t
stride_A
;
ck_tile
::
index_t
N
;
index_t
stride_B
;
ck_tile
::
index_t
K
;
index_t
stride_C
;
ck_tile
::
index_t
stride_A
;
ck_tile
::
index_t
stride_B
;
ck_tile
::
index_t
stride_C
;
};
};
CK_TILE_HOST
static
constexpr
GemmCommonKargs
MakeKargs
(
const
void
*
a_ptr
,
CK_TILE_HOST
static
constexpr
GemmCommonKargs
MakeKargs
(
const
void
*
a_ptr
,
const
void
*
b_ptr
,
const
void
*
b_ptr
,
void
*
c_ptr
,
void
*
c_ptr
,
float
epsilon
,
index_t
M
,
ck_tile
::
index_t
M
,
index_t
N
,
ck_tile
::
index_t
N
,
index_t
K
,
ck_tile
::
index_t
K
,
index_t
stride_A
,
ck_tile
::
index_t
stride_A
,
index_t
stride_B
,
ck_tile
::
index_t
stride_B
,
index_t
stride_C
)
ck_tile
::
index_t
stride_C
)
{
{
return
GemmCommonKargs
{
a_ptr
,
b_ptr
,
c_ptr
,
epsilon
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
};
return
GemmCommonKargs
{
a_ptr
,
b_ptr
,
c_ptr
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
};
}
}
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
{
return
ck_tile
::
max
(
GemmPipeline
::
GetSmemSize
(),
EpiloguePipeline
::
GetSmemSize
());
return
max
(
GemmPipeline
::
GetSmemSize
(),
EpiloguePipeline
::
GetSmemSize
());
}
}
CK_TILE_DEVICE
void
operator
()(
GemmCommonKargs
kargs
)
const
CK_TILE_DEVICE
void
operator
()(
GemmCommonKargs
kargs
)
const
...
@@ -78,13 +74,13 @@ struct GemmKernel
...
@@ -78,13 +74,13 @@ struct GemmKernel
const
BDataType
*
b_start
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
);
const
BDataType
*
b_start
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
);
// Convert pointers to tensor views
// Convert pointers to tensor views
auto
a_tensor_view
=
[
&
]()
{
auto
a_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
Layout
A
,
tensor_layout
::
gemm
::
Column
Major
>
)
if
constexpr
(
std
::
is_same_v
<
A
Layout
,
tensor_layout
::
gemm
::
Row
Major
>
)
{
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_start
,
a_start
,
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
1
,
kargs
.
stride_A
),
make_tuple
(
kargs
.
stride_A
,
1
),
number
<
GemmPipeline
::
Alignment
A
>
{},
number
<
GemmPipeline
::
VectorSize
A
>
{},
number
<
1
>
{});
number
<
1
>
{});
}
}
else
else
...
@@ -92,29 +88,29 @@ struct GemmKernel
...
@@ -92,29 +88,29 @@ struct GemmKernel
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_start
,
a_start
,
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
kargs
.
stride_A
,
1
),
make_tuple
(
1
,
kargs
.
stride_A
),
number
<
GemmPipeline
::
AlignmentA
>
{},
number
<
1
>
{},
number
<
1
>
{});
number
<
1
>
{});
}
}
}();
}();
auto
b_tensor_view
=
[
&
]()
{
auto
b_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
Layout
B
,
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
B
Layout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_start
,
b_start
,
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
1
,
kargs
.
stride_B
),
make_tuple
(
1
,
kargs
.
stride_B
),
number
<
GemmPipeline
::
AlignmentB
>
{},
number
<
1
>
{},
number
<
1
>
{});
number
<
1
>
{});
}
}
else
else
{
// Default NK layout
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_start
,
b_start
,
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
kargs
.
stride_B
,
1
),
make_tuple
(
kargs
.
stride_B
,
1
),
number
<
GemmPipeline
::
Alignment
B
>
{},
number
<
GemmPipeline
::
VectorSize
B
>
{},
number
<
1
>
{});
number
<
1
>
{});
}
}
}();
}();
...
@@ -122,10 +118,12 @@ struct GemmKernel
...
@@ -122,10 +118,12 @@ struct GemmKernel
     auto a_pad_view = pad_tensor_view(
         a_tensor_view,
         make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
-        sequence<0,
-                 // somehow clang-format is splitting below line into multiple.
-                 GemmPipeline::kPadA ? 1 : 0>{});
+        // clang-format off
+        sequence<false, GemmPipeline::kPadA>{});
+        // clang-format on

-    auto ABlockWindow = make_tile_window(
+    auto a_block_window = make_tile_window(
         a_pad_view,
         make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
         {i_m, 0});
...
@@ -133,10 +131,11 @@ struct GemmKernel
     auto b_pad_view = pad_tensor_view(
         b_tensor_view,
         make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
-        sequence<0, GemmPipeline::kPadB ? 1 : 0>{});
+        // clang-format off
+        sequence<false, GemmPipeline::kPadB>{});
+        // clang-format on

-    auto BBlockWindow = make_tile_window(
+    auto b_block_window = make_tile_window(
         b_pad_view,
         make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
         {i_n, 0});
...
@@ -144,20 +143,21 @@ struct GemmKernel
     // allocate LDS
     __shared__ char smem_ptr[GetSmemSize()];

-    const index_t num_loop = (kargs.K + TilePartitioner::kK - 1) / TilePartitioner::kK;
-    auto acc = GemmPipeline{}(ABlockWindow, BBlockWindow, num_loop, smem_ptr);
-    CODataType* c_start = static_cast<CODataType*>(kargs.c_ptr);
+    const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
+    // Run GEMM cooperatively by the whole workgroup.
+    auto c_block_tile =
+        GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
+    CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);

     auto c_tensor_view = [&]() {
-        if constexpr(std::is_same_v<LayoutC, tensor_layout::gemm::ColumnMajor>)
+        if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
         {
             return make_naive_tensor_view<address_space_enum::global>(
                 c_start,
                 make_tuple(kargs.M, kargs.N),
-                make_tuple(1, kargs.stride_C),
-                number<GemmPipeline::AlignmentC>{},
+                make_tuple(kargs.stride_C, 1),
+                number<GemmPipeline::VectorSizeC>{},
                 number<1>{});
         }
         else
...
@@ -165,8 +165,8 @@ struct GemmKernel
             return make_naive_tensor_view<address_space_enum::global>(
                 c_start,
                 make_tuple(kargs.M, kargs.N),
-                make_tuple(kargs.stride_C, 1),
-                number<GemmPipeline::AlignmentC>{},
+                make_tuple(1, kargs.stride_C),
+                number<1>{},
                 number<1>{});
         }
     }();
...
@@ -174,14 +174,15 @@ struct GemmKernel
     auto c_pad_view = pad_tensor_view(
         c_tensor_view,
         make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
-        sequence<0, GemmPipeline::kPadC ? 1 : 0>{});
-    auto CBlockWindow_pad = make_tile_window(
+        // clang-format off
+        sequence<false, GemmPipeline::kPadC>{});
+        // clang-format on
+    auto c_block_window = make_tile_window(
         c_pad_view,
         make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
         {i_m, i_n});

-    EpiloguePipeline{}(CBlockWindow_pad, acc);
+    EpiloguePipeline{}(c_block_window, c_block_tile);
     }
 };
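Editor's note: the stride tuples in this kernel diff encode the usual row-major vs column-major offset math. Below is a minimal standalone sketch (plain C++, not the ck_tile API; all values are illustrative assumptions) of what make_naive_tensor_view's strides mean for an (M, K) matrix.

#include <cstdio>

int main()
{
    const int stride = 8; // hypothetical leading dimension
    const int m = 2, k = 3;
    const int row_major_offset = m * stride + k * 1; // strides (stride, 1), as in the RowMajor branch
    const int col_major_offset = m * 1 + k * stride; // strides (1, stride), as in the other branch
    std::printf("row-major: %d, column-major: %d\n", row_major_offset, col_major_offset);
    return 0;
}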
...
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
View file @ a16bb52d
...
@@ -9,26 +9,30 @@ namespace ck_tile {
 template <typename BlockGemmShape_>
 struct GemmTilePartitioner
 {
-    using BlockGemmShape = ck_tile::remove_cvref_t<BlockGemmShape_>;
+    using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;

-    static constexpr ck_tile::index_t kM = BlockGemmShape::kM;
-    static constexpr ck_tile::index_t kN = BlockGemmShape::kN;
-    static constexpr ck_tile::index_t kK = BlockGemmShape::kK;
+    static constexpr index_t kM = BlockGemmShape::kM;
+    static constexpr index_t kN = BlockGemmShape::kN;
+    static constexpr index_t kK = BlockGemmShape::kK;

-    CK_TILE_HOST static constexpr auto
-    GridSize(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t batch_size)
+    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t batch_size)
     {
-        ck_tile::index_t GridDimX = (M + kM - 1) / kM;
-        ck_tile::index_t GridDimY = (N + kN - 1) / kN;
-        ck_tile::index_t GridDimZ = batch_size;
+        index_t GridDimX = (M + kM - 1) / kM;
+        index_t GridDimY = (N + kN - 1) / kN;
+        index_t GridDimZ = batch_size;
         return dim3(GridDimX, GridDimY, GridDimZ);
     }

+    CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K)
+    {
+        return integer_divide_ceil(K, kK);
+    }
+
     CK_TILE_DEVICE auto operator()()
     {
         const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kM);
         const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kN);
-        return ck_tile::make_tuple(iM, iN);
+        return make_tuple(iM, iN);
     }
 };
 } // namespace ck_tile
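Editor's note: GridSize() and the newly added GetLoopNum() are both plain ceiling divisions. A standalone sketch of the arithmetic, under an assumed 128x128x32 tile (kM/kN/kK) and an assumed 1024x1024x1024 problem:

#include <cstdio>

int main()
{
    const int kM = 128, kN = 128, kK = 32; // assumed tile shape
    const int M = 1024, N = 1024, K = 1024; // assumed problem shape
    const int grid_x   = (M + kM - 1) / kM;  // 8 workgroups along M
    const int grid_y   = (N + kN - 1) / kN;  // 8 workgroups along N
    const int num_loop = (K + kK - 1) / kK;  // 32, same as GetLoopNum(K)
    std::printf("grid = (%d, %d), num_loop = %d\n", grid_x, grid_y, num_loop);
    return 0;
}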
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
0 → 100644
View file @ a16bb52d
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"

namespace ck_tile {

// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem>
struct BaseGemmPipelineAgBgCrMem
{
    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    static constexpr index_t BlockSize = Problem::kBlockSize;
    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

    // TODO: Is this 32K value gfx9 arch specific?
    static constexpr index_t MinMemInFlyBytes = 32768;

    static constexpr index_t WgpPerCU =
        (4 * get_warp_size() / BlockSize) >= 1 ? 4 * get_warp_size() / BlockSize : 1;
    static constexpr index_t FullMemBandPrefetchStages = integer_divide_ceil(
        MinMemInFlyBytes / WgpPerCU,
        (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
    static constexpr index_t PrefetchStages =
        FullMemBandPrefetchStages >= 2
            ? (FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8)
            : 2;

    static constexpr index_t LocalPrefillStages = 1;
    static constexpr index_t GlobalBufferNum    = PrefetchStages;

    CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
    {
        if(num_loop % PrefetchStages == 1) { return TailNumber::One; }
        else if(num_loop % PrefetchStages == 2) { return TailNumber::Two; }
        else if(num_loop % PrefetchStages == 3) { return TailNumber::Three; }
        else if(num_loop % PrefetchStages == 4) { return TailNumber::Four; }
        else if(num_loop % PrefetchStages == 5) { return TailNumber::Five; }
        else if(num_loop % PrefetchStages == 6) { return TailNumber::Six; }
        else if(num_loop % PrefetchStages == 7) { return TailNumber::Seven; }
        else { return TailNumber::Full; }
    }
};
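Editor's note: the PrefetchStages heuristic above sizes the prefetch ring so that at least MinMemInFlyBytes of global loads are in flight per CU, clamped to [2, 8]. A worked example under assumed values (fp16 A/B, a 256-thread block on a 64-wide wavefront, a 128x128x32 tile; these numbers are illustrative, not from the source):

#include <cstdio>

int main()
{
    const int min_mem_in_fly_bytes = 32768;
    const int warp_size = 64, block_size = 256;
    const int wgp_per_cu = (4 * warp_size / block_size) >= 1 ? 4 * warp_size / block_size : 1; // 1
    const int m = 128, n = 128, k = 32, elem_bytes = 2;                                        // fp16
    const int bytes_per_stage = (m * elem_bytes + n * elem_bytes) * k;                         // 16384
    const int full_band = (min_mem_in_fly_bytes / wgp_per_cu + bytes_per_stage - 1) / bytes_per_stage; // 2
    const int prefetch_stages = full_band >= 2 ? (full_band <= 8 ? full_band : 8) : 2;         // clamp to [2, 8]
    std::printf("PrefetchStages = %d\n", prefetch_stages); // prints 2
    return 0;
}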
// Maximum Global Memory throughput pipeline with >=32KB data in fly
// GlobalPrefetchStages: >=2
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
{
    using Base = BaseGemmPipelineAgBgCrMem<Problem>;

    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using CDataType      = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    using ALayout = remove_cvref_t<typename Problem::ALayout>;
    using BLayout = remove_cvref_t<typename Problem::BLayout>;
    using CLayout = remove_cvref_t<typename Problem::CLayout>;

    using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;

    using I0 = number<0>;

    static constexpr index_t BlockSize = Problem::kBlockSize;
    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

    static constexpr index_t VectorSizeA = Problem::VectorSizeA;
    static constexpr index_t VectorSizeB = Problem::VectorSizeB;
    static constexpr index_t VectorSizeC = Problem::VectorSizeC;

    static constexpr bool kPadA = Problem::kPadA;
    static constexpr bool kPadB = Problem::kPadB;
    static constexpr bool kPadC = Problem::kPadC;

    // Where is the right place for HasHotLoop and TailNum ???
    static constexpr bool HasHotLoop = Problem::HasHotLoop;
    static constexpr auto TailNum    = Problem::TailNum;
    static constexpr auto Scheduler  = Problem::Scheduler;

    using Base::PrefetchStages;

    CK_TILE_HOST_DEVICE constexpr index_t GetStaticLdsSize()
    {
        return integer_divide_ceil(
                   sizeof(ADataType) *
                       Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
                   16) *
                   16 +
               sizeof(BDataType) *
                   Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
    }

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    template <GemmPipelineScheduler Scheduler>
    struct PipelineImpl
    {
    };
    template <>
    struct PipelineImpl<GemmPipelineScheduler::Intrawave>
    {
        template <typename DstBlockTile, typename SrcTileWindow>
        CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
                                           SrcTileWindow& dram_tile_window) const
        {
            load_tile(dst_block_tile, dram_tile_window);
            move_tile_window(dram_tile_window, {0, KPerBlock});
        }

        template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
        CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
                                         const SrcBlockTile& src_block_tile,
                                         const ElementFunction& element_func) const
        {
            const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile);
            store_tile(lds_tile_window, block_tile_tmp);
        }
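Editor's note: in the hot loop that follows, the PrefetchStages register buffers are used as a ring: while slot (prefetch_idx + 1) % PrefetchStages is written to LDS, slot prefetch_idx is immediately refilled from global memory. A toy host-side sketch of that rotation, assuming PrefetchStages == 2 (illustrative only):

#include <cstdio>

int main()
{
    const int prefetch_stages = 2; // assumed
    for(int prefetch_idx = 0; prefetch_idx < prefetch_stages; ++prefetch_idx)
    {
        const int prefill_slot  = (prefetch_idx + 1) % prefetch_stages; // written to LDS next
        const int prefetch_slot = prefetch_idx;                         // refilled from global
        std::printf("iter %d: prefill from slot %d, prefetch into slot %d\n",
                    prefetch_idx, prefill_slot, prefetch_slot);
    }
    return 0;
}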
        template <bool HasHotLoop,
                  TailNumber TailNum,
                  typename ADramBlockWindowTmp,
                  typename BDramBlockWindowTmp,
                  typename AElementFunction,
                  typename BElementFunction>
        CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                       const AElementFunction& a_element_func,
                                       const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                       const BElementFunction& b_element_func,
                                       index_t num_loop,
                                       void* p_smem) const
        {
            static_assert(
                std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
                    std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
                "A/B Dram block window should have the same data type as appropriate "
                "([A|B]DataType) defined in Problem definition!");

            static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                              NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                              KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                          "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
                          " or KPerBlock!");

            // ------------------------------------------------------------------------------------
            // Definitions of all needed tiles

            // A tile in LDS
            ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
            constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
            auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);

            // TODO: LDS alignment should come from Policy!
            constexpr index_t a_lds_block_space_size_aligned =
                integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
                                    16) *
                16;

            // B tile in LDS
            BDataType* p_b_lds = static_cast<BDataType*>(
                static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
            constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
            auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);

            // A DRAM tile window for load
            auto a_copy_dram_window =
                make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
                                 make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
                                 a_dram_block_window_tmp.get_window_origin(),
                                 Policy::template MakeADramTileDistribution<Problem>());

            // A LDS tile window for store
            auto a_copy_lds_window =
                make_tile_window(a_lds_block,
                                 make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
                                 {0, 0},
                                 a_copy_dram_window.get_tile_distribution());

            // B DRAM tile window for load
            auto b_copy_dram_window =
                make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
                                 make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
                                 b_dram_block_window_tmp.get_window_origin(),
                                 Policy::template MakeBDramTileDistribution<Problem>());

            // B LDS tile window for store
            auto b_copy_lds_window =
                make_tile_window(b_lds_block,
                                 make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
                                 {0, 0},
                                 b_copy_dram_window.get_tile_distribution());

            // A LDS tile for block GEMM
            auto a_lds_gemm_window = make_tile_window(
                a_lds_block, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});

            // B LDS tile for block GEMM
            auto b_lds_gemm_window = make_tile_window(
                b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});

            // Block GEMM
            constexpr auto block_gemm = BlockGemm();
            auto c_block_tile = block_gemm.MakeCBlockTile();

            using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
            using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());

            using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
            using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));

            tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
            tuple_array<BBlockTile, PrefetchStages> b_block_tiles;

            // -----------------------------------------------------------------------------------------
            // Gemm pipeline start

            // prefetch
            // global read 0
            GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
            GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);

            // initialize C
            tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

            // LDS write 0
            LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
            LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);

            // Global prefetch [1, PrefetchStages]
            static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
                GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
                GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
            });

            // main body
            if constexpr(HasHotLoop)
            {
                index_t i = 0;
                do
                {
                    static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) {
                        block_sync_lds();
                        // block_gemm.LocalPrefetch();
                        block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
                        block_sync_lds();

                        LocalPrefill(a_copy_lds_window,
                                     a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
                                     a_element_func);
                        LocalPrefill(b_copy_lds_window,
                                     b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
                                     b_element_func);

                        GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
                        GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
                    });

                    i += PrefetchStages;
                } while(i < (num_loop - PrefetchStages));
            }

            auto HotLoopTail = [&](auto tail_num) {
                static_for<1, tail_num, 1>{}([&](auto prefetch_idx) {
                    block_sync_lds();
                    // block_gemm.LocalPrefetch();
                    block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
                    block_sync_lds();

                    LocalPrefill(a_copy_lds_window,
                                 a_block_tiles.get(number<prefetch_idx>{}),
                                 a_element_func);
                    LocalPrefill(b_copy_lds_window,
                                 b_block_tiles.get(number<prefetch_idx>{}),
                                 b_element_func);
                });

                block_sync_lds();
                // block_gemm.LocalPrefetch();
                block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
            };

            if constexpr(TailNum == TailNumber::One)
            {
                block_sync_lds();
                // block_gemm.LocalPrefetch();
                block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
            }
            else if constexpr(TailNum == TailNumber::Two) { HotLoopTail(number<2>{}); }
            else if constexpr(TailNum == TailNumber::Three) { HotLoopTail(number<3>{}); }
            else if constexpr(TailNum == TailNumber::Four) { HotLoopTail(number<4>{}); }
            else if constexpr(TailNum == TailNumber::Five) { HotLoopTail(number<5>{}); }
            else if constexpr(TailNum == TailNumber::Six) { HotLoopTail(number<6>{}); }
            else if constexpr(TailNum == TailNumber::Seven) { HotLoopTail(number<7>{}); }
            else if constexpr(TailNum == TailNumber::Full) { HotLoopTail(number<PrefetchStages>{}); }

            return c_block_tile;
        }
    };
    template <typename ADramBlockWindowTmp,
              typename BDramBlockWindowTmp,
              typename AElementFunction,
              typename BElementFunction>
    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                   const AElementFunction& a_element_func,
                                   const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                   const BElementFunction& b_element_func,
                                   index_t num_loop,
                                   void* p_smem) const
    {
        return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
            a_dram_block_window_tmp,
            a_element_func,
            b_dram_block_window_tmp,
            b_element_func,
            num_loop,
            p_smem);
    }

    template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                   const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                   index_t num_loop,
                                   void* p_smem) const
    {
        return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
            a_dram_block_window_tmp,
            [](const ADataType& a) { return a; },
            b_dram_block_window_tmp,
            [](const BDataType& b) { return b; },
            num_loop,
            p_smem);
    }
};
} // namespace ck_tile
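Editor's note: this pipeline is driven from the kernel shown earlier in this commit; the sketch below simply quotes that call sequence as comments (identifiers such as a_block_window and EpiloguePipeline come from gemm_kernel.hpp, not from this header):

// __shared__ char smem_ptr[GetSmemSize()];
// const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
// auto c_block_tile =
//     GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
// EpiloguePipeline{}(c_block_window, c_block_tile);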
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp
0 → 100644
View file @ a16bb52d
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <ostream>

#include "ck_tile/core.hpp"

namespace ck_tile {

enum struct GemmPipelineScheduler
{
    Intrawave,
    Interwave,
};

enum struct TailNumber
{
    // Single / Double buffer pipeline
    Odd,
    Even,

    // Long prefetch pipeline, up to 8
    One,
    Two,
    Three,
    Four,
    Five,
    Six,
    Seven,

    // Unroll stages > Prefetch stages, number of loops is a multiple of unroll stages
    Empty,
    // Unroll stages <= Prefetch stages, number of loops is a multiple of unroll stages plus
    // prefetch stages
    Full,
};

} // namespace ck_tile

inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineScheduler& s)
{
    switch(s)
    {
    case ck_tile::GemmPipelineScheduler::Intrawave: os << "Intrawave"; break;
    case ck_tile::GemmPipelineScheduler::Interwave: os << "Interwave"; break;
    default: os << "";
    }
    return os;
}

inline std::ostream& operator<<(std::ostream& os, const ck_tile::TailNumber& s)
{
    switch(s)
    {
    case ck_tile::TailNumber::Odd: os << "Odd"; break;
    case ck_tile::TailNumber::Even: os << "Even"; break;
    case ck_tile::TailNumber::One: os << "One"; break;
    case ck_tile::TailNumber::Two: os << "Two"; break;
    case ck_tile::TailNumber::Three: os << "Three"; break;
    case ck_tile::TailNumber::Four: os << "Four"; break;
    case ck_tile::TailNumber::Five: os << "Five"; break;
    case ck_tile::TailNumber::Six: os << "Six"; break;
    case ck_tile::TailNumber::Seven: os << "Seven"; break;
    case ck_tile::TailNumber::Empty: os << "Empty"; break;
    case ck_tile::TailNumber::Full: os << "Full"; break;
    default: os << "";
    }
    return os;
}
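Editor's note: for the mem pipeline above, GetBlockLoopTailNum() simply names num_loop % PrefetchStages, with 0 mapping to TailNumber::Full. A worked example under an assumed PrefetchStages of 3 (illustrative only):

#include <cstdio>

int main()
{
    const int prefetch_stages = 3; // assumed
    for(int num_loop = 7; num_loop <= 9; ++num_loop)
    {
        const int r = num_loop % prefetch_stages;
        // 7 -> One, 8 -> Two, 9 -> Full
        std::printf("num_loop = %d -> %s\n", num_loop,
                    r == 0 ? "Full" : (r == 1 ? "One" : "Two"));
    }
    return 0;
}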
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @ a16bb52d
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -19,27 +19,27 @@ struct GemmPipelineAGmemBGmemCRegV1
     using CDataType      = remove_cvref_t<typename Problem::CDataType>;
     using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

-    static constexpr index_t kBlockSize = Problem::kBlockSize;
+    using ALayout = remove_cvref_t<typename Problem::ALayout>;
+    using BLayout = remove_cvref_t<typename Problem::BLayout>;
+    using CLayout = remove_cvref_t<typename Problem::CLayout>;
+
+    static constexpr index_t BlockSize = Problem::kBlockSize;

     static constexpr index_t kMPerBlock = BlockGemmShape::kM;
     static constexpr index_t kNPerBlock = BlockGemmShape::kN;
     static constexpr index_t kKPerBlock = BlockGemmShape::kK;

-    static constexpr index_t AlignmentA = Problem::AlignmentA;
-    static constexpr index_t AlignmentB = Problem::AlignmentB;
-    static constexpr index_t AlignmentC = Problem::AlignmentC;
+    static constexpr index_t VectorSizeA = Problem::VectorSizeA;
+    static constexpr index_t VectorSizeB = Problem::VectorSizeB;
+    static constexpr index_t VectorSizeC = Problem::VectorSizeC;

     static constexpr bool kPadA = Problem::kPadA;
     static constexpr bool kPadB = Problem::kPadB;
     static constexpr bool kPadC = Problem::kPadC;

-    using LayoutA = remove_cvref_t<typename Problem::LayoutA>;
-    using LayoutB = remove_cvref_t<typename Problem::LayoutB>;
-    using LayoutC = remove_cvref_t<typename Problem::LayoutC>;
-
-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
+    CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
     {
-        return ck_tile::integer_divide_ceil(
+        return integer_divide_ceil(
             sizeof(ADataType) *
                 Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
             16) *
...
@@ -48,7 +48,7 @@ struct GemmPipelineAGmemBGmemCRegV1
             Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
     }

-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
         return Policy::template GetSmemSize<Problem>();
     }
...
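Editor's note: GetStaticLdsSize() above rounds the A slab up to a 16-byte boundary before appending the B slab. A worked example of that arithmetic under assumed sizes (fp16 A and B tiles of 128x32 elements each; illustrative, not from the source):

#include <cstdio>

int main()
{
    const int a_bytes   = 2 * 128 * 32;               // sizeof(ADataType) * A element count = 8192
    const int b_bytes   = 2 * 128 * 32;               // sizeof(BDataType) * B element count = 8192
    const int a_aligned = (a_bytes + 15) / 16 * 16;   // round A up to 16 bytes (already aligned here)
    std::printf("lds bytes = %d\n", a_aligned + b_bytes); // 16384
    return 0;
}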
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @ a16bb52d
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -71,8 +71,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
     {
-        using namespace ck_tile;
-
         constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
         constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
...
@@ -93,7 +91,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
     }

     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA()
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
     {
         constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
                                         MakeALdsBlockDescriptor<Problem>().get_element_space_size();
...
@@ -101,7 +99,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
     }

     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeB()
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
     {
         constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
                                         MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
...
@@ -109,7 +107,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
     }

     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
         constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
         constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
View file @ a16bb52d
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -25,9 +25,9 @@ struct GemmPipelineAGmemBGmemCRegV2
     static constexpr index_t kNPerBlock = BlockGemmShape::kN;
     static constexpr index_t kKPerBlock = BlockGemmShape::kK;

-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
+    CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
     {
-        return ck_tile::integer_divide_ceil(
+        return integer_divide_ceil(
             sizeof(ADataType) *
                 Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
             16) *
...