Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
1639689e
Commit
1639689e
authored
May 15, 2023
by
carlushuang
Browse files
update example
parent
e730aeb7
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
215 additions
and
57 deletions
+215
-57
example/01_gemm/CMakeLists.txt
example/01_gemm/CMakeLists.txt
+1
-0
example/01_gemm/common.hpp
example/01_gemm/common.hpp
+73
-2
example/01_gemm/gemm_xdl_fp16.cpp
example/01_gemm/gemm_xdl_fp16.cpp
+11
-2
example/01_gemm/gemm_xdl_streamk.cpp
example/01_gemm/gemm_xdl_streamk.cpp
+11
-3
example/01_gemm/run_gemm_example.inc
example/01_gemm/run_gemm_example.inc
+105
-38
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp
...sor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp
+2
-2
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
+2
-2
profiler/include/profiler/profile_gemm_streamk_impl.hpp
profiler/include/profiler/profile_gemm_streamk_impl.hpp
+5
-4
profiler/src/profile_gemm_streamk.cpp
profiler/src/profile_gemm_streamk.cpp
+5
-4
No files found.
example/01_gemm/CMakeLists.txt
View file @
1639689e
...
@@ -45,3 +45,4 @@ if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS
...
@@ -45,3 +45,4 @@ if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS
endif
()
endif
()
add_example_executable
(
example_gemm_xdl_streamk gemm_xdl_streamk.cpp
)
add_example_executable
(
example_gemm_xdl_streamk gemm_xdl_streamk.cpp
)
set_source_files_properties
(
gemm_xdl_streamk.cpp PROPERTIES COMPILE_OPTIONS
"-v;--save-temps;-Wno-gnu-line-marker"
)
\ No newline at end of file
example/01_gemm/common.hpp
View file @
1639689e
...
@@ -33,6 +33,19 @@ struct ProblemSize final
...
@@ -33,6 +33,19 @@ struct ProblemSize final
ck
::
index_t
StrideC
=
4096
;
ck
::
index_t
StrideC
=
4096
;
};
};
struct
ProblemSizeStreamK
final
{
ck
::
index_t
M
=
3840
;
ck
::
index_t
N
=
4096
;
ck
::
index_t
K
=
4096
;
ck
::
index_t
StrideA
=
4096
;
ck
::
index_t
StrideB
=
4096
;
ck
::
index_t
StrideC
=
4096
;
ck
::
index_t
NumSKBlocks
=
-
1
;
};
struct
ExecutionConfig
final
struct
ExecutionConfig
final
{
{
bool
do_verification
=
true
;
bool
do_verification
=
true
;
...
@@ -48,8 +61,17 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
...
@@ -48,8 +61,17 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
inline
bool
template
<
typename
ProblemType
>
parse_cmd_args
(
int
argc
,
char
*
argv
[],
ProblemSize
&
problem_size
,
ExecutionConfig
&
config
)
bool
parse_cmd_args
(
int
,
char
*
[],
ProblemType
&
,
ExecutionConfig
&
)
{
return
false
;
}
template
<
>
bool
parse_cmd_args
<
ProblemSize
>
(
int
argc
,
char
*
argv
[],
ProblemSize
&
problem_size
,
ExecutionConfig
&
config
)
{
{
if
(
argc
==
1
)
if
(
argc
==
1
)
{
{
...
@@ -87,3 +109,52 @@ parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfi
...
@@ -87,3 +109,52 @@ parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfi
return
true
;
return
true
;
}
}
template
<
>
bool
parse_cmd_args
<
ProblemSizeStreamK
>
(
int
argc
,
char
*
argv
[],
ProblemSizeStreamK
&
problem_size
,
ExecutionConfig
&
config
)
{
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
4
)
{
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
init_method
=
std
::
stoi
(
argv
[
2
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
>=
10
)
{
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
init_method
=
std
::
stoi
(
argv
[
2
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
3
]);
problem_size
.
M
=
std
::
stoi
(
argv
[
4
]);
problem_size
.
N
=
std
::
stoi
(
argv
[
5
]);
problem_size
.
K
=
std
::
stoi
(
argv
[
6
]);
problem_size
.
StrideA
=
std
::
stoi
(
argv
[
7
]);
problem_size
.
StrideB
=
std
::
stoi
(
argv
[
8
]);
problem_size
.
StrideC
=
std
::
stoi
(
argv
[
9
]);
if
(
argc
>=
11
)
{
problem_size
.
NumSKBlocks
=
std
::
stoi
(
argv
[
10
]);
}
}
else
{
std
::
cerr
<<
"arg1: verification (0=no, 1=yes)"
<<
std
::
endl
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
<<
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC"
<<
std
::
endl
<<
"arg10: NumSKBlocks(optional)"
<<
std
::
endl
;
return
false
;
}
return
true
;
}
example/01_gemm/gemm_xdl_fp16.cpp
View file @
1639689e
...
@@ -23,6 +23,7 @@ using BElementOp = PassThrough;
...
@@ -23,6 +23,7 @@ using BElementOp = PassThrough;
using
CElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
// clang-format off
using
DeviceGemmInstance0
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdl
using
DeviceGemmInstance0
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdl
...
@@ -30,7 +31,10 @@ using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl
...
@@ -30,7 +31,10 @@ using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ALayout
,
BLayout
,
CLayout
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
;
// < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ALayout
,
BLayout
,
CLayout
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
,
1
,
ck
::
LoopScheduler
::
Interwave
,
ck
::
PipelineVersion
::
v1
>
;
// // clang-format on
// // clang-format on
// clang-format off
// clang-format off
...
@@ -39,7 +43,12 @@ using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffl
...
@@ -39,7 +43,12 @@ using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffl
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
ck
::
LoopScheduler
::
Default
,
ck
::
PipelineVersion
::
v1
>
;
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
ck
::
LoopScheduler
::
Default
,
ck
::
PipelineVersion
::
v1
>
;
// < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, ck::LoopScheduler::Default, ck::PipelineVersion::v1>;
// DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
// DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
// clang-format on
// clang-format on
...
...
example/01_gemm/gemm_xdl_streamk.cpp
View file @
1639689e
...
@@ -14,7 +14,8 @@ using CDataType = ck::half_t;
...
@@ -14,7 +14,8 @@ using CDataType = ck::half_t;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
ALayout
=
Row
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
BLayout
=
Row
;
// using BLayout = Col;
using
CLayout
=
Row
;
using
CLayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
AElementOp
=
PassThrough
;
...
@@ -27,7 +28,14 @@ using DeviceGemmStreamK = ck::tensor_operation::device::DeviceGemmXdlStreamK
...
@@ -27,7 +28,14 @@ using DeviceGemmStreamK = ck::tensor_operation::device::DeviceGemmXdlStreamK
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ALayout
,
BLayout
,
CLayout
,
AElementOp
,
BElementOp
,
CElementOp
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
// < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8>;
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ALayout
,
BLayout
,
CLayout
,
AElementOp
,
BElementOp
,
CElementOp
,
128
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
;
// < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 128, 32, 128, 4, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8>;
// // clang-format on
// // clang-format on
// clang-format on
// clang-format on
...
@@ -38,4 +46,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host::
...
@@ -38,4 +46,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host::
#include "run_gemm_example.inc"
#include "run_gemm_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_
streamk_
example
(
argc
,
argv
);
}
example/01_gemm/run_gemm_example.inc
View file @
1639689e
...
@@ -3,7 +3,10 @@
...
@@ -3,7 +3,10 @@
#pragma once
#pragma once
bool
run_gemm
(
const
ProblemSize
&
problem_size
,
const
ExecutionConfig
&
config
)
#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp"
template
<
typename
ProblemType
>
bool
run_gemm
(
const
ProblemType
&
problem_size
,
const
ExecutionConfig
&
config
)
{
{
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
static_assert
(
sizeof
(
ck
::
int4_t
)
==
sizeof
(
int8_t
));
static_assert
(
sizeof
(
ck
::
int4_t
)
==
sizeof
(
int8_t
));
...
@@ -11,7 +14,12 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
...
@@ -11,7 +14,12 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
using
namespace
ck
::
literals
;
using
namespace
ck
::
literals
;
auto
[
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
]
=
problem_size
;
auto
M
=
problem_size
.
M
;
auto
N
=
problem_size
.
N
;
auto
K
=
problem_size
.
K
;
auto
StrideA
=
problem_size
.
StrideA
;
auto
StrideB
=
problem_size
.
StrideB
;
auto
StrideC
=
problem_size
.
StrideC
;
auto
f_host_tensor_descriptor
=
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
...
@@ -25,21 +33,23 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
...
@@ -25,21 +33,23 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
}
}
};
};
auto
f_get_default_stride
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
auto
f_get_default_stride
=
if
(
stride
==
0
)
{
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
// give a chance if stride is zero, return a defalt packed stride
if
(
stride
==
0
)
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
return
col
;
// give a chance if stride is zero, return a defalt packed stride
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
col
;
}
else
{
return
row
;
}
}
}
else
else
{
return
stride
;
return
row
;
};
}
}
else
return
stride
;
};
StrideA
=
f_get_default_stride
(
M
,
K
,
StrideA
,
ALayout
{});
StrideA
=
f_get_default_stride
(
M
,
K
,
StrideA
,
ALayout
{});
StrideB
=
f_get_default_stride
(
K
,
N
,
StrideB
,
BLayout
{});
StrideB
=
f_get_default_stride
(
K
,
N
,
StrideB
,
BLayout
{});
...
@@ -94,37 +104,86 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
...
@@ -94,37 +104,86 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
auto
b_element_op
=
BElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
c_element_op
=
CElementOp
{};
auto
c_element_op
=
CElementOp
{};
using
BaseStreamK
=
ck
::
tensor_operation
::
device
::
DeviceGemmStreamK
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
// do GEMM
// do GEMM
auto
gemm
=
DeviceGemmInstance
{};
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
invoker
=
gemm
.
MakeInvoker
();
auto
argument
=
gemm
.
MakeArgument
(
float
ave_time
=
0
;
if
constexpr
(
std
::
is_same
<
ProblemType
,
ProblemSize
>::
value
&&
!
std
::
is_base_of
<
BaseStreamK
,
DeviceGemmInstance
>::
value
)
{
auto
argument
=
gemm
.
MakeArgument
(
#ifdef BUILD_INT4_EXAMPLE
#ifdef BUILD_INT4_EXAMPLE
static_cast
<
KernelADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
KernelADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
KernelBDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
KernelBDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
KernelCDataType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
KernelCDataType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
#else
#else
static_cast
<
ADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
#endif
#endif
M
,
M
,
N
,
N
,
K
,
K
,
StrideA
,
StrideA
,
StrideB
,
StrideB
,
StrideC
,
StrideC
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
);
c_element_op
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
{
std
::
cerr
<<
gemm
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
std
::
cerr
<<
gemm
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
return
true
;
}
return
true
;
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
})
;
}
}
else
if
constexpr
(
std
::
is_same
<
ProblemType
,
ProblemSizeStreamK
>::
value
&&
std
::
is_base_of
<
BaseStreamK
,
DeviceGemmInstance
>::
value
)
{
auto
argument
=
gemm
.
MakeArgument
(
#ifdef BUILD_INT4_EXAMPLE
static_cast
<
KernelADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
KernelBDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
KernelCDataType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
#else
static_cast
<
ADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
#endif
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
a_element_op
,
b_element_op
,
c_element_op
,
problem_size
.
NumSKBlocks
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
std
::
cerr
<<
gemm
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
return
true
;
}
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
});
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
});
}
std
::
size_t
flop
=
2_
uz
*
M
*
N
*
K
;
std
::
size_t
flop
=
2_
uz
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
std
::
size_t
num_btype
=
...
@@ -172,3 +231,11 @@ bool run_gemm_example(int argc, char* argv[])
...
@@ -172,3 +231,11 @@ bool run_gemm_example(int argc, char* argv[])
return
!
parse_cmd_args
(
argc
,
argv
,
problem_size
,
config
)
||
run_gemm
(
problem_size
,
config
);
return
!
parse_cmd_args
(
argc
,
argv
,
problem_size
,
config
)
||
run_gemm
(
problem_size
,
config
);
}
}
bool
run_gemm_streamk_example
(
int
argc
,
char
*
argv
[])
{
ProblemSizeStreamK
problem_size
;
ExecutionConfig
config
;
return
!
parse_cmd_args
(
argc
,
argv
,
problem_size
,
config
)
||
run_gemm
(
problem_size
,
config
);
}
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp
View file @
1639689e
...
@@ -186,7 +186,7 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
...
@@ -186,7 +186,7 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
in
dex
_t
NumSKBlocks
=
0
)
u
in
t32
_t
NumSKBlocks
=
0
xffffffff
)
{
{
const
auto
kernel
=
kernel_gemm_xdlops_streamk
<
GridwiseGemm
>
;
const
auto
kernel
=
kernel_gemm_xdlops_streamk
<
GridwiseGemm
>
;
int
occupancy
,
num_cu
;
int
occupancy
,
num_cu
;
...
@@ -214,7 +214,7 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
...
@@ -214,7 +214,7 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
StrideC
,
StrideC
,
static_cast
<
uint32_t
>
(
num_cu
),
static_cast
<
uint32_t
>
(
num_cu
),
static_cast
<
uint32_t
>
(
occupancy
),
static_cast
<
uint32_t
>
(
occupancy
),
static_cast
<
uint32_t
>
(
NumSKBlocks
)
};
NumSKBlocks
};
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
View file @
1639689e
...
@@ -676,7 +676,7 @@ struct BlockToCTileMap_GemmStreamK
...
@@ -676,7 +676,7 @@ struct BlockToCTileMap_GemmStreamK
uint32_t
k
,
uint32_t
k
,
uint32_t
num_cu
,
uint32_t
num_cu
,
uint32_t
occupancy
,
uint32_t
occupancy
,
uint32_t
sk_blocks
=
0
,
uint32_t
sk_blocks
=
0
xffffffff
,
uint32_t
tile_swizzle_sub_m_factor
=
8
)
uint32_t
tile_swizzle_sub_m_factor
=
8
)
{
{
uint32_t
num_tiles
=
uint32_t
num_tiles
=
...
@@ -777,7 +777,7 @@ struct BlockToCTileMap_GemmStreamK
...
@@ -777,7 +777,7 @@ struct BlockToCTileMap_GemmStreamK
}
}
// give a chance to control num of sk blocks
// give a chance to control num of sk blocks
sk_num_blocks
=
sk_blocks
!=
0
?
sk_blocks
:
sk_num_blocks
;
sk_num_blocks
=
sk_blocks
!=
0
xffffffff
?
sk_blocks
:
sk_num_blocks
;
sk_num_blocks
=
env_get_int
(
"sk_num_blocks"
,
sk_num_blocks
);
sk_num_blocks
=
env_get_int
(
"sk_num_blocks"
,
sk_num_blocks
);
if
(
sk_num_blocks
==
0
)
if
(
sk_num_blocks
==
0
)
...
...
profiler/include/profiler/profile_gemm_streamk_impl.hpp
View file @
1639689e
...
@@ -41,7 +41,7 @@ bool profile_gemm_streamk_impl(int do_verification,
...
@@ -41,7 +41,7 @@ bool profile_gemm_streamk_impl(int do_verification,
int
StrideA
,
int
StrideA
,
int
StrideB
,
int
StrideB
,
int
StrideC
,
int
StrideC
,
int
NumSKBlocks
=
0
)
u
int
32_t
NumSKBlocks
=
0
xffffffff
)
{
{
bool
pass
=
true
;
bool
pass
=
true
;
...
@@ -72,8 +72,8 @@ bool profile_gemm_streamk_impl(int do_verification,
...
@@ -72,8 +72,8 @@ bool profile_gemm_streamk_impl(int do_verification,
{
{
case
0
:
break
;
case
0
:
break
;
case
1
:
case
1
:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
0
,
1
});
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
1
,
1
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
3
,
3
});
break
;
break
;
default:
default:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
...
@@ -110,7 +110,8 @@ bool profile_gemm_streamk_impl(int do_verification,
...
@@ -110,7 +110,8 @@ bool profile_gemm_streamk_impl(int do_verification,
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances, "
<<
(
do_verification
?
"with verification"
:
"without verification"
)
<<
std
::
endl
;
// Run reference GEMM
// Run reference GEMM
if
(
do_verification
)
if
(
do_verification
)
...
...
profiler/src/profile_gemm_streamk.cpp
View file @
1639689e
...
@@ -58,10 +58,11 @@ int profile_gemm_streamk(int argc, char* argv[])
...
@@ -58,10 +58,11 @@ int profile_gemm_streamk(int argc, char* argv[])
const
int
N
=
std
::
stoi
(
argv
[
9
]);
const
int
N
=
std
::
stoi
(
argv
[
9
]);
const
int
K
=
std
::
stoi
(
argv
[
10
]);
const
int
K
=
std
::
stoi
(
argv
[
10
]);
const
int
StrideA
=
std
::
stoi
(
argv
[
11
]);
const
int
StrideA
=
std
::
stoi
(
argv
[
11
]);
const
int
StrideB
=
std
::
stoi
(
argv
[
12
]);
const
int
StrideB
=
std
::
stoi
(
argv
[
12
]);
const
int
StrideC
=
std
::
stoi
(
argv
[
13
]);
const
int
StrideC
=
std
::
stoi
(
argv
[
13
]);
const
int
NumSKBlocks
=
argc
>=
15
?
std
::
stoi
(
argv
[
14
])
:
0
;
const
uint32_t
NumSKBlocks
=
argc
>=
15
?
static_cast
<
uint32_t
>
(
std
::
stoul
(
std
::
string
(
argv
[
14
])))
:
0xffffffff
;
using
F32
=
float
;
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment