Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
04006d5f
Commit
04006d5f
authored
Aug 25, 2024
by
ThomasNing
Browse files
Fix: Clang Format, API fixed from fmha
parent
c2b7f8df
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
379 additions
and
356 deletions
+379
-356
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+141
-103
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+24
-20
include/ck_tile/host/check_err.hpp
include/ck_tile/host/check_err.hpp
+5
-5
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+25
-59
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...a/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+17
-38
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp
...m/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp
+3
-1
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+93
-54
include/ck_tile/ops/gemm/kernel/gemm_matrix_type.hpp
include/ck_tile/ops/gemm/kernel/gemm_matrix_type.hpp
+16
-31
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
+23
-20
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp
...gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp
+4
-3
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...lock_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+13
-10
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
...ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
+3
-4
include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp
include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp
+12
-8
No files found.
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
04006d5f
...
@@ -12,80 +12,91 @@
...
@@ -12,80 +12,91 @@
#include <tuple>
#include <tuple>
/*
/*
create_args is a function
create_args is a function
*/
*/
auto
create_args
(
int
argc
,
char
*
argv
[])
{
auto
create_args
(
int
argc
,
char
*
argv
[])
{
ck_tile
::
ArgParser
arg_parser
;
ck_tile
::
ArgParser
arg_parser
;
arg_parser
.
insert
(
"b"
,
"1"
,
"batch size"
)
arg_parser
.
insert
(
"b"
,
"1"
,
"batch size"
)
.
insert
(
"m"
,
"1024"
,
"m dimension"
)
.
insert
(
"m"
,
"1024"
,
"m dimension"
)
.
insert
(
"n"
,
"2048"
,
"n dimension"
)
.
insert
(
"n"
,
"2048"
,
"n dimension"
)
.
insert
(
"k"
,
"32"
,
"k dimension"
)
.
insert
(
"k"
,
"32"
,
"k dimension"
)
.
insert
(
"stride_a"
,
"0"
,
"stride on apply the m,k A block"
)
.
insert
(
"stride_a"
,
"0"
,
"stride on apply the m,k A block"
)
.
insert
(
"stride_b"
,
"0"
,
"stride on apply the n,k B block"
)
.
insert
(
"stride_b"
,
"0"
,
"stride on apply the n,k B block"
)
.
insert
(
"stride_c"
,
"0"
,
"stride on apply the m,n C block"
)
.
insert
(
"stride_c"
,
"0"
,
"stride on apply the m,n C block"
)
.
insert
(
"grouped"
,
"0"
,
"bool condition on whether it is a grouped gemm"
)
.
insert
(
"grouped"
,
"0"
,
"bool condition on whether it is a grouped gemm"
)
.
insert
(
"grouped_dimension_m"
,
"0"
,
"Fill in the desired dimension when enable grouped gemm"
)
.
insert
(
.
insert
(
"grouped_dimension_n"
,
"0"
,
"Fill in the desired dimension when enable grouped gemm"
)
"grouped_dimension_m"
,
"0"
,
"Fill in the desired dimension when enable grouped gemm"
)
.
insert
(
"grouped_dimension_k"
,
"0"
,
"Fill in the desired dimension when enable grouped gemm"
)
.
insert
(
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
"grouped_dimension_n"
,
"0"
,
"Fill in the desired dimension when enable grouped gemm"
)
.
insert
(
"e"
,
"1e-5"
,
"epsilon"
)
.
insert
(
.
insert
(
"prec"
,
"fp16"
,
"data type. fp16/bf16/fp8/bf8"
)
"grouped_dimension_k"
,
"0"
,
"Fill in the desired dimension when enable grouped gemm"
)
.
insert
(
"following_op"
,
"no"
,
"combined_op. bias/relu/gelu..."
)
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
.
insert
(
"warmup"
,
"10"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"e"
,
"1e-5"
,
"epsilon"
)
.
insert
(
"repeat"
,
"100"
,
"number of iterations to benchmark the kernel"
)
.
insert
(
"prec"
,
"fp16"
,
"data type. fp16/bf16/fp8/bf8"
)
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
);
.
insert
(
"following_op"
,
"no"
,
"combined_op. bias/relu/gelu..."
)
.
insert
(
"warmup"
,
"10"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"repeat"
,
"100"
,
"number of iterations to benchmark the kernel"
)
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
return
std
::
make_tuple
(
result
,
arg_parser
);
return
std
::
make_tuple
(
result
,
arg_parser
);
}
}
template
<
typename
Layouts
>
template
<
typename
Layouts
>
float
gemm_calc
(
gemm_basic_args
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
float
gemm_calc
(
gemm_basic_args
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
// ToDo: This will be modified by the codegen code later.
// ToDo: This will be modified by the codegen code later.
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
constexpr
ck_tile
::
index_t
N_Tile
=
128
;
constexpr
ck_tile
::
index_t
N_Tile
=
128
;
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
constexpr
ck_tile
::
index_t
N_Warp
=
2
;
constexpr
ck_tile
::
index_t
N_Warp
=
2
;
constexpr
ck_tile
::
index_t
K_Warp
=
1
;
constexpr
ck_tile
::
index_t
K_Warp
=
1
;
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadA
=
true
;
constexpr
bool
kPadA
=
true
;
constexpr
bool
kPadB
=
true
;
constexpr
bool
kPadB
=
true
;
constexpr
bool
kPadC
=
false
;
constexpr
bool
kPadC
=
false
;
constexpr
ck_tile
::
index_t
kBlockPerCu
=
1
;
constexpr
ck_tile
::
index_t
kBlockPerCu
=
1
;
// ===============================================
// ===============================================
using
Shape
=
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
using
Shape
=
ck_tile
::
TileGemmShape
NewGemm
<
ck_tile
::
sequence
<
M_
Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_
Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
sequence
<
M_Warp
_Tile
,
N_Warp
_Tile
,
K_Warp
_Tile
>
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
>
;
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>
>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
Shape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
Shape
>
;
using
PipelineProblem
=
ck_tile
::
BlockGemmPipelineProblem
<
XDataType
,
YDataType
,
AccDataType
,
Shape
,
using
PipelineProblem
=
ck_tile
::
kPadA
,
kPadB
,
kPadC
>
;
BlockGemmPipelineProblem
<
XDataType
,
YDataType
,
AccDataType
,
Shape
,
kPadA
,
kPadB
,
kPadC
>
;
// The GemmPipeline should also come from the Codegen.
// The GemmPipeline should also come from the Codegen.
using
GemmPipeline
=
ck_tile
::
BlockGemmPipelineAGmemBGmemCRegV1
<
PipelineProblem
>
;
using
GemmPipeline
=
ck_tile
::
BlockGemmPipelineAGmemBGmemCRegV1
<
PipelineProblem
>
;
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ODataType
,
kPadA
,
kPadB
>>
;
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
ODataType
,
kPadA
,
kPadB
>>
;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
,
Layouts
>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
,
Layouts
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_x
,
args
.
p_x
,
args
.
p_y
,
args
.
p_z
,
args
.
batch_size
,
args
.
epsilon
,
args
.
M
,
args
.
N
,
args
.
p_y
,
args
.
K
,
args
.
stride_A
,
args
.
stride_B
,
args
.
stride_C
args
.
p_z
,
);
args
.
batch_size
,
args
.
epsilon
,
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
batch_size
);
args
.
M
,
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
args
.
N
,
args
.
K
,
args
.
stride_A
,
args
.
stride_B
,
args
.
stride_C
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
batch_size
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
float
ave_time
=
ck_tile
::
launch_kernel
(
float
ave_time
=
ck_tile
::
launch_kernel
(
s
,
ck_tile
::
make_kernel
<
blocks
.
x
,
kBlockPerCu
>
(
Kernel
{},
grids
,
blocks
,
0
,
kargs
));
s
,
ck_tile
::
make_kernel
<
blocks
.
x
,
kBlockPerCu
>
(
Kernel
{},
grids
,
blocks
,
0
,
kargs
));
...
@@ -94,81 +105,103 @@ float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s) {
...
@@ -94,81 +105,103 @@ float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s) {
}
}
template
<
typename
DataType
,
typename
Layouts
>
template
<
typename
DataType
,
typename
Layouts
>
float
OperatorExecution
(
ck_tile
::
DeviceMem
&
x_buf
,
ck_tile
::
DeviceMem
&
y_buf
,
float
OperatorExecution
(
ck_tile
::
DeviceMem
&
x_buf
,
ck_tile
::
DeviceMem
&
z_buf
,
ck_tile
::
DeviceMem
&
y_buf
,
const
ck_tile
::
ArgParser
&
arg_parser
){
ck_tile
::
DeviceMem
&
z_buf
,
const
ck_tile
::
ArgParser
&
arg_parser
)
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
{
if
(
data_type
!=
DataTypeTraits
<
DataType
>::
name
)
{
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
std
::
cerr
<<
"Data type mismatch: expected "
<<
DataTypeTraits
<
DataType
>::
name
<<
", got "
if
(
data_type
!=
DataTypeTraits
<
DataType
>::
name
)
{
std
::
cerr
<<
"Data type mismatch: expected "
<<
DataTypeTraits
<
DataType
>::
name
<<
", got "
<<
data_type
<<
std
::
endl
;
<<
data_type
<<
std
::
endl
;
return
-
1
;
// Or handle the error appropriately
return
-
1
;
// Or handle the error appropriately
}
}
float
epsilon
=
arg_parser
.
get_float
(
"e"
);
float
epsilon
=
arg_parser
.
get_float
(
"e"
);
ck_tile
::
index_t
batch_size
=
arg_parser
.
get_int
(
"b"
);
ck_tile
::
index_t
batch_size
=
arg_parser
.
get_int
(
"b"
);
ck_tile
::
index_t
M
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
M
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
K
=
arg_parser
.
get_int
(
"k"
);
ck_tile
::
index_t
K
=
arg_parser
.
get_int
(
"k"
);
ck_tile
::
index_t
stride_a
=
arg_parser
.
get_int
(
"stride_a"
);
ck_tile
::
index_t
stride_a
=
arg_parser
.
get_int
(
"stride_a"
);
ck_tile
::
index_t
stride_b
=
arg_parser
.
get_int
(
"stride_b"
);
ck_tile
::
index_t
stride_b
=
arg_parser
.
get_int
(
"stride_b"
);
ck_tile
::
index_t
stride_c
=
arg_parser
.
get_int
(
"stride_c"
);
ck_tile
::
index_t
stride_c
=
arg_parser
.
get_int
(
"stride_c"
);
gemm_basic_args
args
;
gemm_basic_args
args
;
args
.
p_x
=
x_buf
.
GetDeviceBuffer
();
args
.
p_x
=
x_buf
.
GetDeviceBuffer
();
args
.
p_y
=
y_buf
.
GetDeviceBuffer
();
args
.
p_y
=
y_buf
.
GetDeviceBuffer
();
args
.
p_z
=
z_buf
.
GetDeviceBuffer
();
args
.
p_z
=
z_buf
.
GetDeviceBuffer
();
args
.
epsilon
=
epsilon
;
args
.
epsilon
=
epsilon
;
args
.
batch_size
=
batch_size
;
args
.
batch_size
=
batch_size
;
args
.
M
=
M
;
args
.
M
=
M
;
args
.
N
=
N
;
args
.
N
=
N
;
args
.
K
=
K
;
args
.
K
=
K
;
// Only set stride_M and stride_N if they are non-zero and not equal to K.
// Only set stride_M and stride_N if they are non-zero and not equal to K.
if
(
stride_a
!=
0
)
{
if
(
stride_a
!=
0
)
{
args
.
stride_A
=
stride_a
;
args
.
stride_A
=
stride_a
;
}
else
{
}
args
.
stride_A
=
[
&
](){
else
if
constexpr
(
Layouts
::
LayoutA
==
ck_tile
::
MatrixALayout
::
KM
)
{
{
args
.
stride_A
=
[
&
]()
{
if
constexpr
(
Layouts
::
LayoutA
==
ck_tile
::
MatrixALayout
::
KM
)
{
return
M
;
return
M
;
}
else
{
}
else
{
return
K
;
return
K
;
}
}
}();
}();
}
}
if
(
stride_b
!=
0
)
{
if
(
stride_b
!=
0
)
{
args
.
stride_B
=
stride_b
;
args
.
stride_B
=
stride_b
;
}
else
{
}
args
.
stride_B
=
[
&
](){
else
if
constexpr
(
Layouts
::
LayoutB
==
ck_tile
::
MatrixBLayout
::
KN
)
{
{
args
.
stride_B
=
[
&
]()
{
if
constexpr
(
Layouts
::
LayoutB
==
ck_tile
::
MatrixBLayout
::
KN
)
{
return
N
;
return
N
;
}
else
{
}
else
{
return
K
;
return
K
;
}
}
}();
}();
}
}
if
(
stride_c
!=
0
)
{
if
(
stride_c
!=
0
)
{
args
.
stride_C
=
stride_c
;
args
.
stride_C
=
stride_c
;
}
else
{
}
args
.
stride_C
=
[
&
](){
else
if
constexpr
(
Layouts
::
LayoutC
==
ck_tile
::
MatrixCLayout
::
NM
)
{
{
args
.
stride_C
=
[
&
]()
{
if
constexpr
(
Layouts
::
LayoutC
==
ck_tile
::
MatrixCLayout
::
NM
)
{
return
M
;
return
M
;
}
else
{
}
else
{
return
N
;
return
N
;
}
}
}();
}();
}
}
float
ave_time
=
gemm_calc
<
Layouts
>
(
args
,
ck_tile
::
stream_config
{
nullptr
,
true
});
float
ave_time
=
gemm_calc
<
Layouts
>
(
args
,
ck_tile
::
stream_config
{
nullptr
,
true
});
std
::
size_t
num_byte
=
sizeof
(
XDataType
)
*
M
*
K
+
sizeof
(
YDataType
)
*
N
*
K
+
std
::
size_t
num_byte
=
sizeof
(
ODataType
)
*
M
*
N
;
sizeof
(
XDataType
)
*
M
*
K
+
sizeof
(
YDataType
)
*
N
*
K
+
sizeof
(
ODataType
)
*
M
*
N
;
float
gb_per_sec
=
num_byte
/
1.E6
/
ave_time
;
float
gb_per_sec
=
num_byte
/
1.E6
/
ave_time
;
std
::
cout
<<
"The overall perfomance of the GEMM with "
<<
"["
<<
data_type
<<
"]"
std
::
cout
<<
"The overall perfomance of the GEMM with "
<<
"["
<<
data_type
<<
"]"
<<
"batch size: "
<<
batch_size
<<
". m:"
<<
M
<<
",n:"
<<
N
<<
", k:"
<<
K
<<
"batch size: "
<<
batch_size
<<
". m:"
<<
M
<<
",n:"
<<
N
<<
", k:"
<<
K
<<
"is:
\n
"
;
<<
"is:
\n
"
;
std
::
cout
<<
"Running time :"
<<
ave_time
<<
"ms, Throughput"
<<
gb_per_sec
<<
"GB/s
\n
"
std
::
cout
<<
"Running time :"
<<
ave_time
<<
"ms, Throughput"
<<
gb_per_sec
<<
"GB/s
\n
"
...
@@ -177,16 +210,17 @@ float OperatorExecution(ck_tile::DeviceMem& x_buf, ck_tile::DeviceMem& y_buf,
...
@@ -177,16 +210,17 @@ float OperatorExecution(ck_tile::DeviceMem& x_buf, ck_tile::DeviceMem& y_buf,
return
ave_time
;
return
ave_time
;
}
}
int
main
(
int
argc
,
char
*
argv
[])
{
int
main
(
int
argc
,
char
*
argv
[])
{
auto
[
result
,
arg_parser
]
=
create_args
(
argc
,
argv
);
auto
[
result
,
arg_parser
]
=
create_args
(
argc
,
argv
);
if
(
!
result
)
if
(
!
result
)
return
-
1
;
return
-
1
;
bool
grouped_enable
=
arg_parser
.
get_bool
(
"grouped"
);
bool
grouped_enable
=
arg_parser
.
get_bool
(
"grouped"
);
std
::
string
following_op_descrp
=
arg_parser
.
get_str
(
"following_op"
);
std
::
string
following_op_descrp
=
arg_parser
.
get_str
(
"following_op"
);
ck_tile
::
index_t
M
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
M
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
K
=
arg_parser
.
get_int
(
"k"
);
ck_tile
::
index_t
K
=
arg_parser
.
get_int
(
"k"
);
constexpr
ck_tile
::
MatrixALayout
matrix_a_layout
=
ck_tile
::
MatrixALayout
::
MK
;
constexpr
ck_tile
::
MatrixALayout
matrix_a_layout
=
ck_tile
::
MatrixALayout
::
MK
;
constexpr
ck_tile
::
MatrixBLayout
matrix_b_layout
=
ck_tile
::
MatrixBLayout
::
NK
;
constexpr
ck_tile
::
MatrixBLayout
matrix_b_layout
=
ck_tile
::
MatrixBLayout
::
NK
;
...
@@ -194,12 +228,15 @@ int main(int argc, char* argv[]) {
...
@@ -194,12 +228,15 @@ int main(int argc, char* argv[]) {
using
Layouts
=
LayoutConfig
<
matrix_a_layout
,
matrix_b_layout
,
matrix_c_layout
>
;
using
Layouts
=
LayoutConfig
<
matrix_a_layout
,
matrix_b_layout
,
matrix_c_layout
>
;
// host verify
// host verify
std
::
vector
<
int
>
x_dimensions
=
(
matrix_a_layout
==
ck_tile
::
MatrixALayout
::
MK
)
?
std
::
vector
<
int
>
x_dimensions
=
(
matrix_a_layout
==
ck_tile
::
MatrixALayout
::
MK
)
std
::
vector
<
int
>
{
M
,
K
}
:
std
::
vector
<
int
>
{
K
,
M
};
?
std
::
vector
<
int
>
{
M
,
K
}
std
::
vector
<
int
>
y_dimensions
=
(
matrix_b_layout
==
ck_tile
::
MatrixBLayout
::
NK
)
?
:
std
::
vector
<
int
>
{
K
,
M
};
std
::
vector
<
int
>
{
N
,
K
}
:
std
::
vector
<
int
>
{
K
,
N
};
std
::
vector
<
int
>
y_dimensions
=
(
matrix_b_layout
==
ck_tile
::
MatrixBLayout
::
NK
)
std
::
vector
<
int
>
z_dimensions
=
(
matrix_c_layout
==
ck_tile
::
MatrixCLayout
::
MN
)
?
?
std
::
vector
<
int
>
{
N
,
K
}
std
::
vector
<
int
>
{
M
,
N
}
:
std
::
vector
<
int
>
{
N
,
M
};
:
std
::
vector
<
int
>
{
K
,
N
};
std
::
vector
<
int
>
z_dimensions
=
(
matrix_c_layout
==
ck_tile
::
MatrixCLayout
::
MN
)
?
std
::
vector
<
int
>
{
M
,
N
}
:
std
::
vector
<
int
>
{
N
,
M
};
ck_tile
::
HostTensor
<
XDataType
>
x_host
(
x_dimensions
);
ck_tile
::
HostTensor
<
XDataType
>
x_host
(
x_dimensions
);
ck_tile
::
HostTensor
<
YDataType
>
y_host
(
y_dimensions
);
ck_tile
::
HostTensor
<
YDataType
>
y_host
(
y_dimensions
);
...
@@ -209,7 +246,7 @@ int main(int argc, char* argv[]) {
...
@@ -209,7 +246,7 @@ int main(int argc, char* argv[]) {
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
5.
f
,
5.
f
}(
x_host
);
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
5.
f
,
5.
f
}(
x_host
);
ck_tile
::
FillUniformDistribution
<
YDataType
>
{
-
5.
f
,
5.
f
}(
y_host
);
ck_tile
::
FillUniformDistribution
<
YDataType
>
{
-
5.
f
,
5.
f
}(
y_host
);
ck_tile
::
DeviceMem
x_buf
(
x_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
x_buf
(
x_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
y_buf
(
y_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
y_buf
(
y_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
z_buf
(
z_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
z_buf
(
z_host_dev
.
get_element_space_size_in_bytes
());
...
@@ -217,16 +254,18 @@ int main(int argc, char* argv[]) {
...
@@ -217,16 +254,18 @@ int main(int argc, char* argv[]) {
x_buf
.
ToDevice
(
x_host
.
data
());
x_buf
.
ToDevice
(
x_host
.
data
());
y_buf
.
ToDevice
(
y_host
.
data
());
y_buf
.
ToDevice
(
y_host
.
data
());
if
(
grouped_enable
||
following_op_descrp
!=
"no"
)
{
if
(
grouped_enable
||
following_op_descrp
!=
"no"
)
{
std
::
cerr
<<
"Other category of the GEMM is unsupported for now!"
<<
std
::
endl
;
std
::
cerr
<<
"Other category of the GEMM is unsupported for now!"
<<
std
::
endl
;
return
-
1
;
return
-
1
;
}
}
OperatorExecution
<
ck_tile
::
half_t
,
Layouts
>
(
x_buf
,
y_buf
,
z_buf
,
arg_parser
);
OperatorExecution
<
ck_tile
::
half_t
,
Layouts
>
(
x_buf
,
y_buf
,
z_buf
,
arg_parser
);
bool
pass
=
true
;
bool
pass
=
true
;
if
(
arg_parser
.
get_bool
(
"v"
))
{
if
(
arg_parser
.
get_bool
(
"v"
))
{
// ToDo: Will Add the Element Op (bias) verification in the future.
// ToDo: Will Add the Element Op (bias) verification in the future.
ck_tile
::
reference_gemm
<
XDataType
,
YDataType
,
AccDataType
,
ODataType
>
(
ck_tile
::
reference_gemm
<
XDataType
,
YDataType
,
AccDataType
,
ODataType
>
(
x_host
,
y_host
,
z_host_ref
,
matrix_a_layout
);
x_host
,
y_host
,
z_host_ref
,
matrix_a_layout
);
...
@@ -239,7 +278,6 @@ int main(int argc, char* argv[]) {
...
@@ -239,7 +278,6 @@ int main(int argc, char* argv[]) {
}
}
std
::
cout
<<
std
::
endl
<<
std
::
flush
;
std
::
cout
<<
std
::
endl
<<
std
::
flush
;
return
!
pass
;
return
!
pass
;
}
}
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
04006d5f
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
...
@@ -15,49 +14,54 @@ template <typename DataType>
...
@@ -15,49 +14,54 @@ template <typename DataType>
struct
GemmBasicTypeConfig
;
struct
GemmBasicTypeConfig
;
template
<
>
template
<
>
struct
GemmBasicTypeConfig
<
ck_tile
::
half_t
>
{
struct
GemmBasicTypeConfig
<
ck_tile
::
half_t
>
using
XDataType
=
ck_tile
::
half_t
;
{
using
YDataType
=
ck_tile
::
half_t
;
using
XDataType
=
ck_tile
::
half_t
;
using
AccDataType
=
float
;
using
YDataType
=
ck_tile
::
half_t
;
using
ODataType
=
ck_tile
::
half_t
;
//type convert
using
AccDataType
=
float
;
using
ODataType
=
ck_tile
::
half_t
;
// type convert
// ToDo: Add more bias config to support different categories of GEMM.
// ToDo: Add more bias config to support different categories of GEMM.
};
};
template
<
ck_tile
::
MatrixALayout
A
,
ck_tile
::
MatrixBLayout
B
,
template
<
ck_tile
::
MatrixALayout
A
,
ck_tile
::
MatrixBLayout
B
,
ck_tile
::
MatrixCLayout
C
>
ck_tile
::
MatrixCLayout
C
>
struct
LayoutConfig
struct
LayoutConfig
{
{
static
constexpr
ck_tile
::
MatrixALayout
LayoutA
=
A
;
static
constexpr
ck_tile
::
MatrixALayout
LayoutA
=
A
;
static
constexpr
ck_tile
::
MatrixBLayout
LayoutB
=
B
;
static
constexpr
ck_tile
::
MatrixBLayout
LayoutB
=
B
;
static
constexpr
ck_tile
::
MatrixCLayout
LayoutC
=
C
;
static
constexpr
ck_tile
::
MatrixCLayout
LayoutC
=
C
;
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
DataTypeTraits
;
struct
DataTypeTraits
;
template
<
>
template
<
>
struct
DataTypeTraits
<
float
>
{
struct
DataTypeTraits
<
float
>
{
static
constexpr
const
char
*
name
=
"float"
;
static
constexpr
const
char
*
name
=
"float"
;
};
};
template
<
>
template
<
>
struct
DataTypeTraits
<
double
>
{
struct
DataTypeTraits
<
double
>
{
static
constexpr
const
char
*
name
=
"double"
;
static
constexpr
const
char
*
name
=
"double"
;
};
};
template
<
>
template
<
>
struct
DataTypeTraits
<
ck_tile
::
half_t
>
{
struct
DataTypeTraits
<
ck_tile
::
half_t
>
{
static
constexpr
const
char
*
name
=
"fp16"
;
static
constexpr
const
char
*
name
=
"fp16"
;
};
};
using
Types
=
GemmBasicTypeConfig
<
ck_tile
::
half_t
>
;
using
Types
=
GemmBasicTypeConfig
<
ck_tile
::
half_t
>
;
// Specific type aliases for easy access
// Specific type aliases for easy access
using
XDataType
=
Types
::
XDataType
;
using
XDataType
=
Types
::
XDataType
;
using
YDataType
=
Types
::
YDataType
;
using
YDataType
=
Types
::
YDataType
;
using
AccDataType
=
Types
::
AccDataType
;
using
AccDataType
=
Types
::
AccDataType
;
using
ODataType
=
Types
::
ODataType
;
using
ODataType
=
Types
::
ODataType
;
struct
gemm_basic_args
{
struct
gemm_basic_args
{
const
void
*
p_x
;
const
void
*
p_x
;
const
void
*
p_y
;
const
void
*
p_y
;
void
*
p_z
;
void
*
p_z
;
...
...
include/ck_tile/host/check_err.hpp
View file @
04006d5f
...
@@ -67,7 +67,7 @@ check_err(const Range& out,
...
@@ -67,7 +67,7 @@ check_err(const Range& out,
int
err_count
=
0
;
int
err_count
=
0
;
double
err
=
0
;
double
err
=
0
;
double
max_err
=
std
::
numeric_limits
<
double
>::
min
();
double
max_err
=
std
::
numeric_limits
<
double
>::
min
();
for
(
std
::
size_t
i
=
419
0
;
i
<
ref
.
size
();
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
ref
.
size
();
++
i
)
{
{
const
double
o
=
*
std
::
next
(
std
::
begin
(
out
),
i
);
const
double
o
=
*
std
::
next
(
std
::
begin
(
out
),
i
);
const
double
r
=
*
std
::
next
(
std
::
begin
(
ref
),
i
);
const
double
r
=
*
std
::
next
(
std
::
begin
(
ref
),
i
);
...
@@ -127,7 +127,7 @@ check_err(const Range& out,
...
@@ -127,7 +127,7 @@ check_err(const Range& out,
double
err
=
0
;
double
err
=
0
;
// TODO: This is a hack. We should have proper specialization for bf16_t data type.
// TODO: This is a hack. We should have proper specialization for bf16_t data type.
double
max_err
=
std
::
numeric_limits
<
float
>::
min
();
double
max_err
=
std
::
numeric_limits
<
float
>::
min
();
for
(
std
::
size_t
i
=
419
0
;
i
<
ref
.
size
();
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
ref
.
size
();
++
i
)
{
{
const
double
o
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
out
),
i
));
const
double
o
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
out
),
i
));
const
double
r
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
ref
),
i
));
const
double
r
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
ref
),
i
));
...
@@ -186,7 +186,7 @@ check_err(const Range& out,
...
@@ -186,7 +186,7 @@ check_err(const Range& out,
int
err_count
=
0
;
int
err_count
=
0
;
double
err
=
0
;
double
err
=
0
;
double
max_err
=
static_cast
<
double
>
(
std
::
numeric_limits
<
ranges
::
range_value_t
<
Range
>>::
min
());
double
max_err
=
static_cast
<
double
>
(
std
::
numeric_limits
<
ranges
::
range_value_t
<
Range
>>::
min
());
for
(
std
::
size_t
i
=
419
0
;
i
<
ref
.
size
();
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
ref
.
size
();
++
i
)
{
{
const
double
o
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
out
),
i
));
const
double
o
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
out
),
i
));
const
double
r
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
ref
),
i
));
const
double
r
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
ref
),
i
));
...
@@ -314,7 +314,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
...
@@ -314,7 +314,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
int
err_count
=
0
;
int
err_count
=
0
;
double
err
=
0
;
double
err
=
0
;
double
max_err
=
std
::
numeric_limits
<
float
>::
min
();
double
max_err
=
std
::
numeric_limits
<
float
>::
min
();
for
(
std
::
size_t
i
=
419
0
;
i
<
ref
.
size
();
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
ref
.
size
();
++
i
)
{
{
const
fp8_t
o_fp8
=
*
std
::
next
(
std
::
begin
(
out
),
i
);
const
fp8_t
o_fp8
=
*
std
::
next
(
std
::
begin
(
out
),
i
);
const
fp8_t
r_fp8
=
*
std
::
next
(
std
::
begin
(
ref
),
i
);
const
fp8_t
r_fp8
=
*
std
::
next
(
std
::
begin
(
ref
),
i
);
...
@@ -372,7 +372,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
...
@@ -372,7 +372,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
int
err_count
=
0
;
int
err_count
=
0
;
double
err
=
0
;
double
err
=
0
;
double
max_err
=
std
::
numeric_limits
<
float
>::
min
();
double
max_err
=
std
::
numeric_limits
<
float
>::
min
();
for
(
std
::
size_t
i
=
419
0
;
i
<
ref
.
size
();
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
ref
.
size
();
++
i
)
{
{
const
double
o
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
out
),
i
));
const
double
o
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
out
),
i
));
const
double
r
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
ref
),
i
));
const
double
r
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
ref
),
i
));
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
04006d5f
...
@@ -1144,17 +1144,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1144,17 +1144,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
{
using
BlockGemmProblem
=
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
typename
Problem
::
KDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kM0
,
typename
Problem
::
AccDataType
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kK0
>>
;
Problem
::
BlockFmhaShape
::
BlockTile
::
kM0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK0
>
,
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps_
,
Problem
::
BlockFmhaShape
::
Gemm0WarpTile_
>>
;
constexpr
auto
warp_gemm
=
[]()
{
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
...
@@ -1184,18 +1178,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1184,18 +1178,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetPTOGradTBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetPTOGradTBlockGemm
()
{
{
using
BlockGemmProblem
=
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
typename
Problem
::
OGradDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
typename
Problem
::
AccDataType
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kVHeaddim
,
Problem
::
kBlockSize
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK1
>>
;
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK1
>
,
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps_
,
Problem
::
BlockFmhaShape
::
Gemm1WarpTile_
>>
;
using
WarpGemm
=
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
@@ -1217,18 +1204,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1217,18 +1204,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetOGradVBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetOGradVBlockGemm
()
{
{
using
BlockGemmProblem
=
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
OGradDataType
,
BlockGemmPipelineProblem
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
typename
Problem
::
VDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kM0
,
typename
Problem
::
AccDataType
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
Problem
::
kBlockSize
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK2
>>
;
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kM0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK2
>
,
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps_
,
Problem
::
BlockFmhaShape
::
Gemm2WarpTile_
>>
;
constexpr
auto
warp_gemm
=
[]()
{
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
OGradDataType
,
half_t
>
&&
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
OGradDataType
,
half_t
>
&&
...
@@ -1295,18 +1275,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1295,18 +1275,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradTQTBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradTQTBlockGemm
()
{
{
using
BlockGemmProblem
=
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
typename
Problem
::
QDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
typename
Problem
::
AccDataType
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kQKHeaddim
,
Problem
::
kBlockSize
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK3
>>
;
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK3
>
,
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps_
,
Problem
::
BlockFmhaShape
::
Gemm3WarpTile_
>>
;
using
WarpGemm
=
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
@@ -1328,18 +1301,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1328,18 +1301,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradKTBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradKTBlockGemm
()
{
{
using
BlockGemmProblem
=
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
typename
Problem
::
KDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kM0
,
typename
Problem
::
AccDataType
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kQKHeaddim
,
Problem
::
kBlockSize
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK4
>>
;
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kM0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK4
>
,
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps_
,
Problem
::
BlockFmhaShape
::
Gemm4WarpTile_
>>
;
using
WarpGemm
=
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
View file @
04006d5f
...
@@ -75,18 +75,11 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
...
@@ -75,18 +75,11 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
{
using
BlockGemmProblem
=
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
Problem
::
kBlockSize
,
typename
Problem
::
KDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kM0
,
typename
Problem
::
SaccDataType
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
Problem
::
kBlockSize
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK0
>>
;
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kM0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK0
>
,
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps_
,
Problem
::
BlockFmhaShape
::
Gemm0WarpTile_
>>
;
constexpr
auto
warp_gemm
=
[]()
{
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
...
@@ -202,18 +195,11 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
...
@@ -202,18 +195,11 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
{
using
BlockGemmProblem
=
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
Problem
::
kBlockSize
,
typename
Problem
::
KDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kM0
,
typename
Problem
::
SaccDataType
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
Problem
::
kBlockSize
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK0
>>
;
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kM0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK0
>
,
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps_
,
Problem
::
BlockFmhaShape
::
Gemm0WarpTile_
>>
;
constexpr
auto
warp_gemm
=
[]()
{
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
...
@@ -500,7 +486,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
...
@@ -500,7 +486,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template
<
typename
Problem
,
index_t
IBuf
=
0
>
template
<
typename
Problem
,
index_t
IBuf
=
0
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKLdsStoreBlockDescriptor
(
number
<
IBuf
>
=
number
<
0
>
{})
MakeKLdsStoreBlockDescriptor
(
number
<
IBuf
>
=
number
<
0
>
{})
{
{
// K is always k-major, we use async-copy to load into LDS
// K is always k-major, we use async-copy to load into LDS
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
...
@@ -555,7 +541,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
...
@@ -555,7 +541,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
template
<
typename
Problem
,
index_t
IBuf
=
0
>
template
<
typename
Problem
,
index_t
IBuf
=
0
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKLdsLoadBlockDescriptor
(
number
<
IBuf
>
=
number
<
0
>
{})
MakeKLdsLoadBlockDescriptor
(
number
<
IBuf
>
=
number
<
0
>
{})
{
{
// K is always k-major, we use async-copy to load into LDS
// K is always k-major, we use async-copy to load into LDS
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
...
@@ -950,18 +936,11 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
...
@@ -950,18 +936,11 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetKVBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetKVBlockGemm
()
{
{
using
BlockGemmProblem
=
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
PDataType
,
BlockGemmPipelineProblem
<
typename
Problem
::
PDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
OaccDataType
,
Problem
::
kBlockSize
,
typename
Problem
::
VDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kM0
,
typename
Problem
::
OaccDataType
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kN1
,
Problem
::
kBlockSize
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK1
>>
;
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kM0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kN1
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK1
>
,
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps_
,
Problem
::
BlockFmhaShape
::
Gemm1WarpTile_
>>
;
auto
warp_gemm
=
[
&
]()
{
auto
warp_gemm
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
...
...
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp
View file @
04006d5f
...
@@ -48,7 +48,9 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
...
@@ -48,7 +48,9 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
std
::
is_same_v
<
typename
Problem
::
CDataType
,
float
>
)
std
::
is_same_v
<
typename
Problem
::
CDataType
,
float
>
)
{
{
return
make_tuple
(
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
{},
4
,
1
);
return
make_tuple
(
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
{},
4
,
1
);
}
else
{
}
else
{
static_assert
(
false
,
"Unsupported data type configuration for GEMM warp execution."
);
static_assert
(
false
,
"Unsupported data type configuration for GEMM warp execution."
);
}
}
}
}
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
04006d5f
...
@@ -12,29 +12,34 @@
...
@@ -12,29 +12,34 @@
namespace
ck_tile
{
namespace
ck_tile
{
template
<
typename
TilePartitioner_
,
typename
GemmPipeline_
,
typename
EpiloguePipeline_
,
typename
Layouts_
>
template
<
typename
TilePartitioner_
,
struct
GemmKernel
{
typename
GemmPipeline_
,
using
TilePartitioner
=
remove_cvref_t
<
TilePartitioner_
>
;
typename
EpiloguePipeline_
,
using
GemmPipeline
=
remove_cvref_t
<
GemmPipeline_
>
;
typename
Layouts_
>
using
EpiloguePipeline
=
remove_cvref_t
<
EpiloguePipeline_
>
;
struct
GemmKernel
using
Layouts
=
remove_cvref_t
<
Layouts_
>
;
{
static
constexpr
index_t
kBlockSize
=
GemmPipeline
::
kBlockSize
;
using
TilePartitioner
=
remove_cvref_t
<
TilePartitioner_
>
;
using
GemmPipeline
=
remove_cvref_t
<
GemmPipeline_
>
;
using
ADataType
=
remove_cvref_t
<
typename
GemmPipeline
::
ADataType
>
;
using
EpiloguePipeline
=
remove_cvref_t
<
EpiloguePipeline_
>
;
using
BDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
BDataType
>
;
using
Layouts
=
remove_cvref_t
<
Layouts_
>
;
using
CAccDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
CDataType
>
;
static
constexpr
index_t
kBlockSize
=
GemmPipeline
::
kBlockSize
;
using
CODataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
ODataType
>
;
using
ADataType
=
remove_cvref_t
<
typename
GemmPipeline
::
ADataType
>
;
__host__
static
constexpr
auto
GridSize
(
index_t
M_size
,
index_t
N_size
,
index_t
Batch_size
)
{
using
BDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
BDataType
>
;
using
CAccDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
CDataType
>
;
using
CODataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
ODataType
>
;
__host__
static
constexpr
auto
GridSize
(
index_t
M_size
,
index_t
N_size
,
index_t
Batch_size
)
{
auto
x
=
TilePartitioner
::
GridSize
(
M_size
,
N_size
,
Batch_size
);
auto
x
=
TilePartitioner
::
GridSize
(
M_size
,
N_size
,
Batch_size
);
printf
(
"GridDimX: %d, GridDimY: %d, %d"
,
x
.
x
,
x
.
y
,
x
.
z
);
printf
(
"GridDimX: %d, GridDimY: %d, %d"
,
x
.
x
,
x
.
y
,
x
.
z
);
return
TilePartitioner
::
GridSize
(
M_size
,
N_size
,
Batch_size
);
return
TilePartitioner
::
GridSize
(
M_size
,
N_size
,
Batch_size
);
}
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
struct
GemmCommonKargs
{
struct
GemmCommonKargs
{
const
void
*
a_ptr
;
const
void
*
a_ptr
;
const
void
*
b_ptr
;
const
void
*
b_ptr
;
void
*
c_ptr
;
void
*
c_ptr
;
...
@@ -60,15 +65,19 @@ struct GemmKernel {
...
@@ -60,15 +65,19 @@ struct GemmKernel {
ck_tile
::
index_t
K
,
ck_tile
::
index_t
K
,
ck_tile
::
index_t
stride_A
,
ck_tile
::
index_t
stride_A
,
ck_tile
::
index_t
stride_B
,
ck_tile
::
index_t
stride_B
,
ck_tile
::
index_t
stride_C
)
{
ck_tile
::
index_t
stride_C
)
return
GemmCommonKargs
{
a_ptr
,
b_ptr
,
c_ptr
,
epsilon
,
batch_size
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
};
{
return
GemmCommonKargs
{
a_ptr
,
b_ptr
,
c_ptr
,
epsilon
,
batch_size
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
};
}
}
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
ck_tile
::
max
(
GemmPipeline
::
GetSmemSize
(),
EpiloguePipeline
::
GetSmemSize
());
return
ck_tile
::
max
(
GemmPipeline
::
GetSmemSize
(),
EpiloguePipeline
::
GetSmemSize
());
}
}
CK_TILE_DEVICE
void
operator
()(
GemmCommonKargs
kargs
)
const
{
CK_TILE_DEVICE
void
operator
()(
GemmCommonKargs
kargs
)
const
{
const
auto
[
i_tile_m
,
i_tile_n
,
i_batch
]
=
TilePartitioner
{}();
const
auto
[
i_tile_m
,
i_tile_n
,
i_batch
]
=
TilePartitioner
{}();
const
index_t
i_m
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
TilePartitioner
::
kM
);
const
index_t
i_m
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
TilePartitioner
::
kM
);
const
index_t
i_n
=
__builtin_amdgcn_readfirstlane
(
i_tile_n
*
TilePartitioner
::
kN
);
const
index_t
i_n
=
__builtin_amdgcn_readfirstlane
(
i_tile_n
*
TilePartitioner
::
kN
);
...
@@ -76,66 +85,96 @@ struct GemmKernel {
...
@@ -76,66 +85,96 @@ struct GemmKernel {
const
ADataType
*
a_start
=
static_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
);
const
ADataType
*
a_start
=
static_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
);
const
BDataType
*
b_start
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
);
const
BDataType
*
b_start
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
);
// Convert pointers to tensor views
// Convert pointers to tensor views
auto
a_tensor_view
=
[
&
](){
auto
a_tensor_view
=
[
&
]()
{
if
constexpr
(
Layouts
::
LayoutA
==
ck_tile
::
MatrixALayout
::
KM
)
{
if
constexpr
(
Layouts
::
LayoutA
==
ck_tile
::
MatrixALayout
::
KM
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_start
,
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
1
,
kargs
.
stride_A
),
a_start
,
number
<
GemmPipeline
::
AlignmentA
>
{},
number
<
1
>
{});
make_tuple
(
kargs
.
M
,
kargs
.
K
),
}
else
{
make_tuple
(
1
,
kargs
.
stride_A
),
number
<
GemmPipeline
::
AlignmentA
>
{},
number
<
1
>
{});
}
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_start
,
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
kargs
.
stride_A
,
1
),
a_start
,
number
<
GemmPipeline
::
AlignmentA
>
{},
number
<
1
>
{});
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
kargs
.
stride_A
,
1
),
number
<
GemmPipeline
::
AlignmentA
>
{},
number
<
1
>
{});
}
}
}();
}();
auto
b_tensor_view
=
[
&
](){
auto
b_tensor_view
=
[
&
]()
{
if
constexpr
(
Layouts
::
LayoutB
==
ck_tile
::
MatrixBLayout
::
KN
)
{
if
constexpr
(
Layouts
::
LayoutB
==
ck_tile
::
MatrixBLayout
::
KN
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_start
,
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
1
,
kargs
.
stride_B
),
b_start
,
number
<
GemmPipeline
::
AlignmentB
>
{},
number
<
1
>
{});
make_tuple
(
kargs
.
N
,
kargs
.
K
),
}
else
{
// Default NK layout
make_tuple
(
1
,
kargs
.
stride_B
),
number
<
GemmPipeline
::
AlignmentB
>
{},
number
<
1
>
{});
}
else
{
// Default NK layout
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_start
,
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
kargs
.
stride_B
,
1
),
b_start
,
number
<
GemmPipeline
::
AlignmentB
>
{},
number
<
1
>
{});
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
kargs
.
stride_B
,
1
),
number
<
GemmPipeline
::
AlignmentB
>
{},
number
<
1
>
{});
}
}
}();
}();
auto
ABlockWindow
=
make_tile_window
(
a_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_m
,
0
});
auto
ABlockWindow
=
make_tile_window
(
auto
BBlockWindow
=
make_tile_window
(
b_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
a_tensor_view
,
number
<
TilePartitioner
::
kK
>
{}),
{
i_n
,
0
});
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_m
,
0
});
auto
BBlockWindow
=
make_tile_window
(
b_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_n
,
0
});
// allocate LDS
// allocate LDS
__shared__
char
smem_ptr
[
GetSmemSize
()];
__shared__
char
smem_ptr
[
GetSmemSize
()];
const
index_t
num_loop
=
(
kargs
.
K
+
TilePartitioner
::
kK
-
1
)
/
TilePartitioner
::
kK
;
const
index_t
num_loop
=
(
kargs
.
K
+
TilePartitioner
::
kK
-
1
)
/
TilePartitioner
::
kK
;
auto
acc
=
BlockGemmPipelineAGmemBGmemCRegV1
<
GemmPipeline
>
{}(
auto
acc
=
BlockGemmPipelineAGmemBGmemCRegV1
<
GemmPipeline
>
{}(
ABlockWindow
,
BBlockWindow
,
num_loop
,
smem_ptr
);
ABlockWindow
,
BBlockWindow
,
num_loop
,
smem_ptr
);
CODataType
*
c_start
=
static_cast
<
CODataType
*>
(
kargs
.
c_ptr
);
CODataType
*
c_start
=
static_cast
<
CODataType
*>
(
kargs
.
c_ptr
);
auto
c_tensor_view
=
[
&
](){
auto
c_tensor_view
=
[
&
]()
{
if
constexpr
(
Layouts
::
LayoutC
==
ck_tile
::
MatrixCLayout
::
NM
){
if
constexpr
(
Layouts
::
LayoutC
==
ck_tile
::
MatrixCLayout
::
NM
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
c_start
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
1
,
kargs
.
stride_C
),
c_start
,
number
<
GemmPipeline
::
AlignmentC
>
{},
number
<
1
>
{});
make_tuple
(
kargs
.
M
,
kargs
.
N
),
}
else
{
make_tuple
(
1
,
kargs
.
stride_C
),
number
<
GemmPipeline
::
AlignmentC
>
{},
number
<
1
>
{});
}
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
c_start
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
stride_C
,
1
),
c_start
,
number
<
GemmPipeline
::
AlignmentC
>
{},
number
<
1
>
{});
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
stride_C
,
1
),
number
<
GemmPipeline
::
AlignmentC
>
{},
number
<
1
>
{});
}
}
}();
}();
auto
CBlockWindow
=
make_tile_window
(
c_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
auto
CBlockWindow
=
make_tile_window
(
number
<
TilePartitioner
::
kN
>
{}),
{
i_m
,
i_n
});
c_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
{
i_m
,
i_n
});
// epilogue.
// epilogue.
EpiloguePipeline
{}(
CBlockWindow
,
acc
);
EpiloguePipeline
{}(
CBlockWindow
,
acc
);
}
}
};
};
}
}
// namespace ck_tile
include/ck_tile/ops/gemm/kernel/gemm_matrix_type.hpp
View file @
04006d5f
...
@@ -4,36 +4,21 @@
...
@@ -4,36 +4,21 @@
#pragma once
#pragma once
namespace
ck_tile
{
namespace
ck_tile
{
enum
struct
MatrixALayout
{
enum
struct
MatrixALayout
MK
,
// Row-major layout for matrix A (default)
{
KM
// Column-major layout for matrix A
MK
,
// Row-major layout for matrix A (default)
};
KM
// Column-major layout for matrix A
};
enum
struct
MatrixBLayout
{
enum
struct
MatrixBLayout
NK
,
// Row-major layout for matrix B (default)
{
KN
// Column-major layout for matrix B
NK
,
// Row-major layout for matrix B (default)
};
KN
// Column-major layout for matrix B
};
enum
struct
MatrixCLayout
{
enum
struct
MatrixCLayout
MN
,
// Row-major layout for matrix C (default)
{
NM
// Column-major layout for matrix C
MN
,
// Row-major layout for matrix C (default)
};
NM
// Column-major layout for matrix C
};
// Function to convert string to MatrixALayout
}
// namespace ck_tile
inline
MatrixALayout
parse_layout_a
(
const
std
::
string
&
layout
)
{
if
(
layout
==
"KM"
)
return
MatrixALayout
::
KM
;
return
MatrixALayout
::
MK
;
// Default to MK if not specified as KM
}
// Function to convert string to MatrixBLayout
inline
MatrixBLayout
parse_layout_b
(
const
std
::
string
&
layout
)
{
if
(
layout
==
"KN"
)
return
MatrixBLayout
::
KN
;
return
MatrixBLayout
::
NK
;
// Default to NK if not specified as KN
}
// Function to convert string to MatrixBLayout
inline
MatrixCLayout
parse_layout_c
(
const
std
::
string
&
layout
)
{
if
(
layout
==
"NM"
)
return
MatrixCLayout
::
NM
;
return
MatrixCLayout
::
MN
;
// Default to MN if not specified as NM
}
}
// namespace ck_tile
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
View file @
04006d5f
...
@@ -6,27 +6,30 @@
...
@@ -6,27 +6,30 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
template
<
typename
BlockGemmShape_
>
template
<
typename
BlockGemmShape_
>
struct
GemmTilePartitioner
{
struct
GemmTilePartitioner
using
BlockGemmShape
=
ck_tile
::
remove_cvref_t
<
BlockGemmShape_
>
;
{
using
BlockGemmShape
=
ck_tile
::
remove_cvref_t
<
BlockGemmShape_
>
;
static
constexpr
ck_tile
::
index_t
kM
=
BlockGemmShape
::
kM
;
static
constexpr
ck_tile
::
index_t
kM
=
BlockGemmShape
::
kM
;
static
constexpr
ck_tile
::
index_t
kN
=
BlockGemmShape
::
kN
;
static
constexpr
ck_tile
::
index_t
kN
=
BlockGemmShape
::
kN
;
static
constexpr
ck_tile
::
index_t
kK
=
BlockGemmShape
::
kK
;
static
constexpr
ck_tile
::
index_t
kK
=
BlockGemmShape
::
kK
;
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
M
,
ck_tile
::
index_t
N
,
CK_TILE_HOST
static
constexpr
auto
ck_tile
::
index_t
batch_size
)
{
GridSize
(
ck_tile
::
index_t
M
,
ck_tile
::
index_t
N
,
ck_tile
::
index_t
batch_size
)
ck_tile
::
index_t
GridDimX
=
(
M
+
kM
-
1
)
/
kM
;
{
ck_tile
::
index_t
GridDimY
=
(
N
+
kN
-
1
)
/
kN
;
ck_tile
::
index_t
GridDimX
=
(
M
+
kM
-
1
)
/
kM
;
ck_tile
::
index_t
GridDimZ
=
batch_size
;
ck_tile
::
index_t
GridDimY
=
(
N
+
kN
-
1
)
/
kN
;
return
dim3
(
GridDimX
,
GridDimY
,
GridDimZ
);
ck_tile
::
index_t
GridDimZ
=
batch_size
;
}
return
dim3
(
GridDimX
,
GridDimY
,
GridDimZ
);
}
CK_TILE_DEVICE
auto
operator
()()
{
CK_TILE_DEVICE
auto
operator
()()
const
index_t
i_GridDimX
=
blockIdx
.
x
;
{
const
index_t
i_GridDimY
=
blockIdx
.
y
;
const
index_t
i_GridDimX
=
blockIdx
.
x
;
const
index_t
i_GridDimZ
=
blockIdx
.
z
;
const
index_t
i_GridDimY
=
blockIdx
.
y
;
return
ck_tile
::
make_tuple
(
i_GridDimX
,
i_GridDimY
,
i_GridDimZ
);
const
index_t
i_GridDimZ
=
blockIdx
.
z
;
}
return
ck_tile
::
make_tuple
(
i_GridDimX
,
i_GridDimY
,
i_GridDimZ
);
};
}
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
04006d5f
...
@@ -40,7 +40,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
...
@@ -40,7 +40,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>().
get_element_space_size
();
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>().
get_element_space_size
();
}
}
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
}
...
@@ -149,7 +150,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
...
@@ -149,7 +150,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
}
}
index_t
iCounter
=
num_loop
-
1
;
index_t
iCounter
=
num_loop
-
1
;
while
(
iCounter
>
0
)
{
while
(
iCounter
>
0
)
{
// global read i + 1
// global read i + 1
a_block_tile
=
load_tile
(
a_copy_dram_window
);
a_block_tile
=
load_tile
(
a_copy_dram_window
);
b_block_tile
=
load_tile
(
b_copy_dram_window
);
b_block_tile
=
load_tile
(
b_copy_dram_window
);
...
@@ -174,7 +176,6 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
...
@@ -174,7 +176,6 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
store_tile
(
b_copy_lds_window
,
b_block_tile_tmp
);
store_tile
(
b_copy_lds_window
,
b_block_tile_tmp
);
iCounter
--
;
iCounter
--
;
}
}
// tail
// tail
...
...
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
04006d5f
...
@@ -91,26 +91,29 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
...
@@ -91,26 +91,29 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
return
b_lds_block_desc
;
return
b_lds_block_desc
;
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeA
()
{
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeA
()
{
constexpr
index_t
smem_size_a
=
sizeof
(
typename
Problem
::
ADataType
)
*
constexpr
index_t
smem_size_a
=
sizeof
(
typename
Problem
::
ADataType
)
*
MakeALdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
MakeALdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_a
;
return
smem_size_a
;
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeB
()
{
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeB
()
{
constexpr
index_t
smem_size_b
=
sizeof
(
typename
Problem
::
BDataType
)
*
constexpr
index_t
smem_size_b
=
sizeof
(
typename
Problem
::
BDataType
)
*
MakeBLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
MakeBLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_b
;
return
smem_size_b
;
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
constexpr
index_t
smem_size_a
=
GetSmemSizeA
<
Problem
>
();
{
constexpr
index_t
smem_size_b
=
GetSmemSizeB
<
Problem
>
();
constexpr
index_t
smem_size_a
=
GetSmemSizeA
<
Problem
>
();
index_t
smem_size
=
0
;
constexpr
index_t
smem_size_b
=
GetSmemSizeB
<
Problem
>
();
index_t
smem_size
=
0
;
smem_size
+=
smem_size_a
+
smem_size_b
;
smem_size
+=
smem_size_a
+
smem_size_b
;
return
smem_size
;
return
smem_size
;
...
...
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
View file @
04006d5f
...
@@ -22,14 +22,13 @@ struct BlockGemmPipelineProblem
...
@@ -22,14 +22,13 @@ struct BlockGemmPipelineProblem
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
64
;
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
64
;
static
constexpr
bool
kPadA
=
kPadA_
;
static
constexpr
bool
kPadA
=
kPadA_
;
static
constexpr
bool
kPadB
=
kPadB_
;
static
constexpr
bool
kPadB
=
kPadB_
;
static
constexpr
bool
kPadC
=
kPadC_
;
static
constexpr
bool
kPadC
=
kPadC_
;
static
constexpr
index_t
AlignmentA
=
kPadA
?
16
/
sizeof
(
ADataType
)
:
1
;
static
constexpr
index_t
AlignmentA
=
kPadA
?
16
/
sizeof
(
ADataType
)
:
1
;
static
constexpr
index_t
AlignmentB
=
kPadB
?
16
/
sizeof
(
BDataType
)
:
1
;
static
constexpr
index_t
AlignmentB
=
kPadB
?
16
/
sizeof
(
BDataType
)
:
1
;
static
constexpr
index_t
AlignmentC
=
kPadC
?
16
/
sizeof
(
CDataType
)
:
1
;
static
constexpr
index_t
AlignmentC
=
kPadC
?
16
/
sizeof
(
CDataType
)
:
1
;
};
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp
View file @
04006d5f
...
@@ -7,17 +7,21 @@
...
@@ -7,17 +7,21 @@
namespace
ck_tile
{
namespace
ck_tile
{
template
<
typename
BlockTile_
,
template
<
index_t
kMPerTile
,
index_t
kNPerTile
,
index_t
kKPerTile
>
typename
BlockWarps_
,
struct
TileGemmShape
{
typename
WarpTile_
>
static
constexpr
index_t
kM
=
kMPerTile
;
struct
TileGemmShape
static
constexpr
index_t
kN
=
kNPerTile
;
static
constexpr
index_t
kK
=
kKPerTile
;
};
template
<
typename
BlockTile_
,
typename
BlockWarps_
,
typename
WarpTile_
>
struct
TileGemmShapeNewGemm
{
{
using
BlockTile
=
remove_cvref_t
<
BlockTile_
>
;
using
BlockTile
=
remove_cvref_t
<
BlockTile_
>
;
using
BlockWarps
=
remove_cvref_t
<
BlockWarps_
>
;
using
BlockWarps
=
remove_cvref_t
<
BlockWarps_
>
;
using
WarpTile
=
remove_cvref_t
<
WarpTile_
>
;
using
WarpTile
=
remove_cvref_t
<
WarpTile_
>
;
static
constexpr
index_t
NumWarps
=
static
constexpr
index_t
NumWarps
=
reduce_on_sequence
(
BlockWarps
{},
multiplies
{},
number
<
1
>
{});
reduce_on_sequence
(
BlockWarps
{},
multiplies
{},
number
<
1
>
{});
static
constexpr
index_t
kM
=
BlockTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kM
=
BlockTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kN
=
BlockTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kN
=
BlockTile
::
at
(
number
<
1
>
{});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment