gaoqiong / composable_kernel · Commits · 49facb91
Commit 49facb91, authored Nov 07, 2023 by Harisankar Sadasivan
Add files for GEMV and tall-and-skinny GEMM examples, and corresponding entries to ckProfiler.
Parent: 98fd41f5
Changes: 25 — showing 5 changed files with 1113 additions and 0 deletions (+1113, -0)
library/src/tensor_operation_instance/gpu/tall_and_skinny_gemm_splitk/device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_nk_mn_instance.cpp  +199 -0
profiler/include/profiler/profile_gemv_splitk_impl.hpp  +297 -0
profiler/include/profiler/profile_tall_and_skinny_gemm_splitk_impl.hpp  +297 -0
profiler/src/profile_gemv_splitk.cpp  +160 -0
profiler/src/profile_tall_and_skinny_gemm_splitk.cpp  +160 -0
library/src/tensor_operation_instance/gpu/tall_and_skinny_gemm_splitk/device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_nk_mn_instance.cpp (new file, 0 → 100755)
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmMNPadding =
    ck::tensor_operation::device::GemmSpecialization::MNPadding;

// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple<
// clang-format off
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer | ABlockTransfer| ABlockTransfer | BBlockTransfer| BThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|SrcVectorTensorLengths| SrcVectorTensor|DstVectorTensorLengths| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | KBatch_K0_M0_M1_K1| KBatch_K0_M0_M1_K1| ArrangeOrder| Order| KBatch_K0_M0_M1_K1 | ContiguousDimOrder| KBatch_K0_M0_M1_K1 | Order| | | Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
///< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, B, M1, B*N1, K0, K1, M1, N1, 1, S<1,1, 1, 1, K1>, S<1,K0, 1,M1, 1>,S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, 4, K1, S<0, 1, 2, 3, 4, 5>, 5, N1>;
//M1 is always tied to 16
//N1=2
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 1, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 1, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 1, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 2, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 2, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 2, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 3, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 3, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 3, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 4, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 4, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 4, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 5, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 5, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 5, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 6, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 6, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 6, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 7, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 7, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 7, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 8, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 8, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 8, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// //N1=4
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 1, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 1, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 1, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 2, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 2, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 2, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 3, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 3, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 3, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 4, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 4, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 4, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 5, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 5, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 5, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 6, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 6, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 6, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 7, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 7, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 7, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 8, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 8, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 8, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// //N1=8
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 1, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 1, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 1, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 2, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 2, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 2, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 3, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 3, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 3, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 4, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 4, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
    ck::tensor_operation::device::deviceTsmmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 4, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 5, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 5, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 5, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 6, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 6, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 6, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 7, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 7, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 7, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 8, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 8, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 8, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>
// clang-format on
    >;

void add_device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceTsmm<Row, Col, Row, F16, F16, F16,
                                           PassThrough, PassThrough, PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances, device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_nk_mn_instances{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
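These instances are picked up by the profiler through CK's generic instance factory rather than by calling the add_...() function directly. A condensed sketch of that lookup, mirroring what the profiler headers later in this commit do (DeviceOp and the aliases follow the definitions above):

// Sketch only: enumerate the registered tall-and-skinny split-K instances.
using DeviceOp = ck::tensor_operation::device::DeviceTsmm<Row, Col, Row, F16, F16, F16,
                                                          PassThrough, PassThrough, PassThrough>;

// One unique_ptr per deviceTsmmDl configuration in the tuple above.
const auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<DeviceOp>::GetInstances();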
profiler/include/profiler/profile_gemv_splitk_impl.hpp (new file, 0 → 100755)
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include <iostream>
#include <typeinfo>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_tall_and_skinny_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemv_splitk.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck {
namespace profiler {

template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType,
          typename ALayout, typename BLayout, typename CLayout>
bool profile_gemv_splitk_impl(int do_verification, int init_method, bool do_log, bool time_kernel,
                              int M, int N, int K, int StrideA, int StrideB, int StrideC, int KBatch)
{
    bool pass = true;

    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            using namespace ck::literals;

            if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
            {
                return HostTensorDescriptor({row, col}, {stride, 1_uz});
            }
            else
            {
                return HostTensorDescriptor({row, col}, {1_uz, stride});
            }
        };

    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
    Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
    Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
    std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-1, 2});
        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 2});
        break;
    default:
        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
    }

    using AElementOp = ck::tensor_operation::element_wise::PassThrough;
    using BElementOp = ck::tensor_operation::element_wise::PassThrough;
    using CElementOp = ck::tensor_operation::element_wise::PassThrough;

    const auto a_element_op = AElementOp{};
    const auto b_element_op = BElementOp{};
    const auto c_element_op = CElementOp{};

    DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
    DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());

    a_device_buf.ToDevice(a_m_k.mData.data());
    b_device_buf.ToDevice(b_k_n.mData.data());

    using DeviceOp = ck::tensor_operation::device::DeviceTsmm<ALayout, BLayout, CLayout,
                                                              ADataType, BDataType, CDataType,
                                                              AElementOp, BElementOp, CElementOp>;

    // get device op instances
    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;

    // Run reference GEMM
    if(do_verification)
    {
        using ReferenceGemmInstance =
            ck::tensor_operation::host::ReferenceGemm<ADataType, BDataType, CDataType, AccDataType,
                                                      AElementOp, BElementOp, CElementOp>;

        auto ref_gemm    = ReferenceGemmInstance{};
        auto ref_invoker = ref_gemm.MakeInvoker();

        auto ref_argument = ref_gemm.MakeArgument(
            a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);

        ref_invoker.Run(ref_argument);
    }

    std::string best_op_name;
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;
    float best_kbatch     = 0;

    // profile device GEMM instances
    for(auto& op_ptr : op_ptrs)
    {
        std::vector<int> kbatch_list = {1,  2,  4,  8,  12,  16,  20,  24,  32,  36,  40, 60,
                                        64, 72, 80, 88, 96, 128, 144, 160, 176, 192, 256};

        if(KBatch > 0)
        {
            kbatch_list = {KBatch};
        }

        for(std::size_t i = 0; i < kbatch_list.size(); i++)
        {
            auto kbatch_curr = kbatch_list[i];

            auto argument_ptr = op_ptr->MakeArgumentPointer(
                static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                M, N, K, StrideA, StrideB, StrideC,
                a_element_op, b_element_op, c_element_op, kbatch_curr);

            auto invoker_ptr = op_ptr->MakeInvokerPointer();

            if(op_ptr->IsSupportedArgument(argument_ptr.get()))
            {
                // re-init C to zero before profiling next kernel
                c_device_buf.SetZero();

                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});

                if(do_verification)
                {
                    c_device_buf.FromDevice(c_m_n_device_result.mData.data());

                    pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);

                    if(do_log)
                    {
                        LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
                        LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
                        LogRangeAsType<float>(std::cout << "c_host : ", c_m_n_host_result.mData, ",")
                            << std::endl;
                        LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
                            << std::endl;
                    }
                }

                std::string op_name = op_ptr->GetTypeString();

                float ave_time =
                    invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

                std::size_t flop = std::size_t(2) * M * N * K;

                std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
                                        sizeof(CDataType) * M * N;

                float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

                float gb_per_sec = num_btype / 1.E6 / ave_time;

                std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
                          << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
                          << kbatch_curr << std::endl;

#if defined CK_ENABLE_FP8
                // set softer tolerances for fp8
                if constexpr(is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> ||
                             is_same_v<CDataType, f8_t>)
                {
                    std::string msg = "Error: Incorrect results!";
                    double rtol     = 1e-1;
                    double atol     = 1e-1;
                    pass            = pass & ck::utils::check_err(
                                      c_m_n_device_result, c_m_n_host_result, msg, rtol, atol);
                }
                else
                {
#endif
                    pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#if defined CK_ENABLE_FP8
                }
#endif

                if(tflops > best_tflops)
                {
                    best_op_name    = op_name;
                    best_tflops     = tflops;
                    best_ave_time   = ave_time;
                    best_gb_per_sec = gb_per_sec;
                    best_kbatch     = kbatch_curr;
                }
            }
            else
            {
                std::cout << op_ptr->GetTypeString() << " does not support this problem"
                          << std::endl;
            }
        }
    }

    if constexpr(is_same<CDataType, float>::value)
    {
        std::cout << "Best Perf for datatype = f32";
    }
    else if constexpr(is_same<CDataType, half_t>::value)
    {
        std::cout << "Best Perf for datatype = f16";
    }
    else if constexpr(is_same<CDataType, bhalf_t>::value)
    {
        std::cout << "Best Perf for datatype = bf16";
    }
    else if constexpr(is_same<CDataType, int8_t>::value)
    {
        std::cout << "Best Perf for datatype = int8";
    }

    if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value)
    {
        std::cout << " ALayout = RowMajor";
    }
    else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value)
    {
        std::cout << " ALayout = ColumnMajor";
    }

    if constexpr(is_same<BLayout, tensor_layout::gemm::RowMajor>::value)
    {
        std::cout << " BLayout = RowMajor";
    }
    else if constexpr(is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value)
    {
        std::cout << " BLayout = ColumnMajor";
    }

    std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
              << " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch
              << " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
              << " GB/s, " << best_op_name << std::endl;

    return pass;
}

} // namespace profiler
} // namespace ck
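The perf figures printed by this header come from a 2*M*N*K flop count and the raw byte traffic of the three matrices; since ave_time is reported in milliseconds, dividing Gflop by ms yields TFLOPS and MB by ms yields GB/s. A small worked example with assumed (not measured) numbers for a tall-and-skinny fp16 problem:

// Hypothetical shape and timing, for illustration only.
const std::size_t M = 16, N = 4096, K = 4096;
const float ave_time = 0.05f; // ms, assumed

const std::size_t flop      = std::size_t(2) * M * N * K;        // ~0.537 Gflop
const std::size_t num_btype = 2 * M * K + 2 * K * N + 2 * M * N; // fp16 bytes, ~33.8 MB

const float tflops     = static_cast<float>(flop) / 1.E9 / ave_time; // ~10.7 TFlops
const float gb_per_sec = num_btype / 1.E6 / ave_time;                // ~676 GB/s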
profiler/include/profiler/profile_tall_and_skinny_gemm_splitk_impl.hpp (new file, 0 → 100755)
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include <iostream>
#include <typeinfo>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_tall_and_skinny_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/tall_and_skinny_gemm_splitk.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck {
namespace profiler {

template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType,
          typename ALayout, typename BLayout, typename CLayout>
bool profile_tall_and_skinny_gemm_splitk_impl(int do_verification, int init_method, bool do_log,
                                              bool time_kernel, int M, int N, int K, int StrideA,
                                              int StrideB, int StrideC, int KBatch)
{
    bool pass = true;

    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            using namespace ck::literals;

            if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
            {
                return HostTensorDescriptor({row, col}, {stride, 1_uz});
            }
            else
            {
                return HostTensorDescriptor({row, col}, {1_uz, stride});
            }
        };

    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
    Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
    Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
    std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-1, 2});
        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 2});
        break;
    default:
        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
    }

    using AElementOp = ck::tensor_operation::element_wise::PassThrough;
    using BElementOp = ck::tensor_operation::element_wise::PassThrough;
    using CElementOp = ck::tensor_operation::element_wise::PassThrough;

    const auto a_element_op = AElementOp{};
    const auto b_element_op = BElementOp{};
    const auto c_element_op = CElementOp{};

    DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
    DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());

    a_device_buf.ToDevice(a_m_k.mData.data());
    b_device_buf.ToDevice(b_k_n.mData.data());

    using DeviceOp = ck::tensor_operation::device::DeviceTsmm<ALayout, BLayout, CLayout,
                                                              ADataType, BDataType, CDataType,
                                                              AElementOp, BElementOp, CElementOp>;

    // get device op instances
    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;

    // Run reference GEMM
    if(do_verification)
    {
        using ReferenceGemmInstance =
            ck::tensor_operation::host::ReferenceGemm<ADataType, BDataType, CDataType, AccDataType,
                                                      AElementOp, BElementOp, CElementOp>;

        auto ref_gemm    = ReferenceGemmInstance{};
        auto ref_invoker = ref_gemm.MakeInvoker();

        auto ref_argument = ref_gemm.MakeArgument(
            a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);

        ref_invoker.Run(ref_argument);
    }

    std::string best_op_name;
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;
    float best_kbatch     = 0;

    // profile device GEMM instances
    for(auto& op_ptr : op_ptrs)
    {
        std::vector<int> kbatch_list = {1,  2,  4,  8,  12,  16,  20,  24,  32,  36,  40, 60,
                                        64, 72, 80, 88, 96, 128, 144, 160, 176, 192, 256};

        if(KBatch > 0)
        {
            kbatch_list = {KBatch};
        }

        for(std::size_t i = 0; i < kbatch_list.size(); i++)
        {
            auto kbatch_curr = kbatch_list[i];

            auto argument_ptr = op_ptr->MakeArgumentPointer(
                static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                M, N, K, StrideA, StrideB, StrideC,
                a_element_op, b_element_op, c_element_op, kbatch_curr);

            auto invoker_ptr = op_ptr->MakeInvokerPointer();

            if(op_ptr->IsSupportedArgument(argument_ptr.get()))
            {
                // re-init C to zero before profiling next kernel
                c_device_buf.SetZero();

                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});

                if(do_verification)
                {
                    c_device_buf.FromDevice(c_m_n_device_result.mData.data());

                    pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);

                    if(do_log)
                    {
                        LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
                        LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
                        LogRangeAsType<float>(std::cout << "c_host : ", c_m_n_host_result.mData, ",")
                            << std::endl;
                        LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
                            << std::endl;
                    }
                }

                std::string op_name = op_ptr->GetTypeString();

                float ave_time =
                    invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

                std::size_t flop = std::size_t(2) * M * N * K;

                std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
                                        sizeof(CDataType) * M * N;

                float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

                float gb_per_sec = num_btype / 1.E6 / ave_time;

                std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
                          << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
                          << kbatch_curr << std::endl;

#if defined CK_ENABLE_FP8
                // set softer tolerances for fp8
                if constexpr(is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> ||
                             is_same_v<CDataType, f8_t>)
                {
                    std::string msg = "Error: Incorrect results!";
                    double rtol     = 1e-1;
                    double atol     = 1e-1;
                    pass            = pass & ck::utils::check_err(
                                      c_m_n_device_result, c_m_n_host_result, msg, rtol, atol);
                }
                else
                {
#endif
                    pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#if defined CK_ENABLE_FP8
                }
#endif

                if(tflops > best_tflops)
                {
                    best_op_name    = op_name;
                    best_tflops     = tflops;
                    best_ave_time   = ave_time;
                    best_gb_per_sec = gb_per_sec;
                    best_kbatch     = kbatch_curr;
                }
            }
            else
            {
                std::cout << op_ptr->GetTypeString() << " does not support this problem"
                          << std::endl;
            }
        }
    }

    if constexpr(is_same<CDataType, float>::value)
    {
        std::cout << "Best Perf for datatype = f32";
    }
    else if constexpr(is_same<CDataType, half_t>::value)
    {
        std::cout << "Best Perf for datatype = f16";
    }
    else if constexpr(is_same<CDataType, bhalf_t>::value)
    {
        std::cout << "Best Perf for datatype = bf16";
    }
    else if constexpr(is_same<CDataType, int8_t>::value)
    {
        std::cout << "Best Perf for datatype = int8";
    }

    if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value)
    {
        std::cout << " ALayout = RowMajor";
    }
    else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value)
    {
        std::cout << " ALayout = ColumnMajor";
    }

    if constexpr(is_same<BLayout, tensor_layout::gemm::RowMajor>::value)
    {
        std::cout << " BLayout = RowMajor";
    }
    else if constexpr(is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value)
    {
        std::cout << " BLayout = ColumnMajor";
    }

    std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
              << " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch
              << " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
              << " GB/s, " << best_op_name << std::endl;

    return pass;
}

} // namespace profiler
} // namespace ck
profiler/src/profile_gemv_splitk.cpp (new file, 0 → 100755)
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_gemv_splitk_impl.hpp"
#include "profiler_operation_registry.hpp"
enum struct GemmMatrixLayout
{
    MK_KN_MN, // 0
    MK_NK_MN, // 1
    KM_KN_MN, // 2
    KM_NK_MN, // 3
};

enum struct GemmDataType
{
    F32_F32_F32,    // 0
    F16_F16_F16,    // 1
    BF16_BF16_BF16, // 2
    INT8_INT8_INT8, // 3
    F8_F16_F16,     // 4
    F16_F8_F16,     // 5
};

#define OP_NAME "gemv_splitk"
#define OP_DESC "Split-K GEMM"

int profile_gemv_splitk(int argc, char* argv[])
{
    if(argc != 15)
    {
        printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
        printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8)\n");
        printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
        printf("                     1: A[m, k] * B[n, k] = C[m, n];\n");
        printf("                     2: A[k, m] * B[k, n] = C[m, n];\n");
        printf("                     3: A[k, m] * B[n, k] = C[m, n])\n");
        printf("arg4: verification (0: no; 1: yes)\n");
        printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
        printf("arg6: print tensor value (0: no; 1: yes)\n");
        printf("arg7: time kernel (0=no, 1=yes)\n");
        printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
        printf("arg14: split k into multiple batches\n");
        exit(1);
    }

    const auto data_type       = static_cast<GemmDataType>(std::stoi(argv[2]));
    const auto layout          = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
    const bool do_verification = std::stoi(argv[4]);
    const int init_method      = std::stoi(argv[5]);
    const bool do_log          = std::stoi(argv[6]);
    const bool time_kernel     = std::stoi(argv[7]);

    const int M = std::stoi(argv[8]);
    const int N = std::stoi(argv[9]);
    const int K = std::stoi(argv[10]);

    const int StrideA = std::stoi(argv[11]);
    const int StrideB = std::stoi(argv[12]);
    const int StrideC = std::stoi(argv[13]);
    const int KBatch  = std::stoi(argv[14]);

    using F32 = float;
    using F16 = ck::half_t;
    // #if defined CK_ENABLE_FP8
    // using F8 = ck::f8_t;
    // #endif

    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    auto profile = [&](auto a_type, auto b_type, auto acc_type, auto c_type,
                       auto a_layout, auto b_layout, auto c_layout) {
        using ADataType   = decltype(a_type);
        using BDataType   = decltype(b_type);
        using AccDataType = decltype(acc_type);
        using CDataType   = decltype(c_type);

        using ALayout = decltype(a_layout);
        using BLayout = decltype(b_layout);
        using CLayout = decltype(c_layout);

        const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
        const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
        const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;

        bool pass = ck::profiler::profile_gemv_splitk_impl<ADataType, BDataType, AccDataType,
                                                           CDataType, ALayout, BLayout, CLayout>(
            do_verification,
            init_method,
            do_log,
            time_kernel,
            M,
            N,
            K,
            (StrideA < 0) ? DefaultStrideA : StrideA,
            (StrideB < 0) ? DefaultStrideB : StrideB,
            (StrideC < 0) ? DefaultStrideC : StrideC,
            KBatch);

        return pass ? 0 : 1;
    };
// if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
// {
// return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{});
// }
// else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
// {
// return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{});
// }
// else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
// {
// return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{});
// }
// else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
// {
// return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{});
// }
    if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
    {
        return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
    }
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
    {
        return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
    }
// else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
// {
// return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{});
// }
// else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
// {
// return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{});
// }
    else
    {
        std::cout << "this data_type & layout is not implemented" << std::endl;

        return 1;
    }
}

REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemv_splitk);
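With this registration, the new op is selectable in the CK profiler under the name gemv_splitk, using the 14-argument order documented by the usage message above. A hypothetical invocation (the binary name ckProfiler and the problem shape are assumptions, not part of this diff):

# fp16 (1), layout MK_NK_MN (1), verify (1), integer init (1), no log (0), time kernel (1),
# M N K = 16 1 4096, StrideA/B/C = -1 (use layout defaults), KBatch = 0 (sweep the built-in list)
./bin/ckProfiler gemv_splitk 1 1 1 1 0 1 16 1 4096 -1 -1 -1 0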
profiler/src/profile_tall_and_skinny_gemm_splitk.cpp (new file, 0 → 100755)
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_tall_and_skinny_gemm_splitk_impl.hpp"
#include "profiler_operation_registry.hpp"
enum struct GemmMatrixLayout
{
    MK_KN_MN, // 0
    MK_NK_MN, // 1
    KM_KN_MN, // 2
    KM_NK_MN, // 3
};

enum struct GemmDataType
{
    F32_F32_F32,    // 0
    F16_F16_F16,    // 1
    BF16_BF16_BF16, // 2
    INT8_INT8_INT8, // 3
    F8_F16_F16,     // 4
    F16_F8_F16,     // 5
};

#define OP_NAME "tall_and_skinny_gemm_splitk"
#define OP_DESC "Tall and Skinny GEMM splitk"

int profile_tall_and_skinny_gemm_splitk(int argc, char* argv[])
{
    if(argc != 15)
    {
        printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
        printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8)\n");
        printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
        printf("                     1: A[m, k] * B[n, k] = C[m, n];\n");
        printf("                     2: A[k, m] * B[k, n] = C[m, n];\n");
        printf("                     3: A[k, m] * B[n, k] = C[m, n])\n");
        printf("arg4: verification (0: no; 1: yes)\n");
        printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
        printf("arg6: print tensor value (0: no; 1: yes)\n");
        printf("arg7: time kernel (0=no, 1=yes)\n");
        printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
        printf("arg14: split k into multiple batches\n");
        exit(1);
    }

    const auto data_type       = static_cast<GemmDataType>(std::stoi(argv[2]));
    const auto layout          = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
    const bool do_verification = std::stoi(argv[4]);
    const int init_method      = std::stoi(argv[5]);
    const bool do_log          = std::stoi(argv[6]);
    const bool time_kernel     = std::stoi(argv[7]);

    const int M = std::stoi(argv[8]);
    const int N = std::stoi(argv[9]);
    const int K = std::stoi(argv[10]);

    const int StrideA = std::stoi(argv[11]);
    const int StrideB = std::stoi(argv[12]);
    const int StrideC = std::stoi(argv[13]);
    const int KBatch  = std::stoi(argv[14]);

    using F32 = float;
    using F16 = ck::half_t;
    // #if defined CK_ENABLE_FP8
    // using F8 = ck::f8_t;
    // #endif

    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    auto profile = [&](auto a_type, auto b_type, auto acc_type, auto c_type,
                       auto a_layout, auto b_layout, auto c_layout) {
        using ADataType   = decltype(a_type);
        using BDataType   = decltype(b_type);
        using AccDataType = decltype(acc_type);
        using CDataType   = decltype(c_type);

        using ALayout = decltype(a_layout);
        using BLayout = decltype(b_layout);
        using CLayout = decltype(c_layout);

        const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
        const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
        const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;

        bool pass = ck::profiler::profile_tall_and_skinny_gemm_splitk_impl<ADataType, BDataType,
                                                                           AccDataType, CDataType,
                                                                           ALayout, BLayout, CLayout>(
            do_verification,
            init_method,
            do_log,
            time_kernel,
            M,
            N,
            K,
            (StrideA < 0) ? DefaultStrideA : StrideA,
            (StrideB < 0) ? DefaultStrideB : StrideB,
            (StrideC < 0) ? DefaultStrideC : StrideC,
            KBatch);

        return pass ? 0 : 1;
    };
// if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
// {
// return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{});
// }
// else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
// {
// return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{});
// }
// else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
// {
// return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{});
// }
// else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
// {
// return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{});
// }
    if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
    {
        return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
    }
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
    {
        return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
    }
// else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
// {
// return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{});
// }
// else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
// {
// return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{});
// }
    else
    {
        std::cout << "this data_type & layout is not implemented" << std::endl;

        return 1;
    }
}

REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_tall_and_skinny_gemm_splitk);
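The same argument order applies to tall_and_skinny_gemm_splitk. Note that arg14 only pins the split-K factor when it is positive; zero (or any non-positive value) makes the implementation sweep its built-in KBatch list (1, 2, 4, ..., 256) and report the best configuration found. A hypothetical call with the binary name and shape assumed, not taken from this commit:

./bin/ckProfiler tall_and_skinny_gemm_splitk 1 1 1 1 0 1 16 128 8192 -1 -1 -1 0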