Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
39cfca6f
Commit
39cfca6f
authored
Apr 19, 2022
by
j4yan
Browse files
add gemm_dlops_f16
parent
3ac4aea4
Changes
18
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
915 additions
and
16 deletions
+915
-16
include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp
include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp
+1
-1
include/ck/utility/inner_product.hpp
include/ck/utility/inner_product.hpp
+36
-0
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
...ary/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
+8
-0
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_kn_mn_instance.cpp
.../gemm/device_gemm_dlops_f16_f16_f16_km_kn_mn_instance.cpp
+69
-0
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_nk_mn_instance.cpp
.../gemm/device_gemm_dlops_f16_f16_f16_km_nk_mn_instance.cpp
+70
-0
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_kn_mn_instance.cpp
.../gemm/device_gemm_dlops_f16_f16_f16_mk_kn_mn_instance.cpp
+69
-0
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_nk_mn_instance.cpp
.../gemm/device_gemm_dlops_f16_f16_f16_mk_nk_mn_instance.cpp
+70
-0
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp
.../gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp
+1
-2
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp
.../gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp
+1
-1
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp
.../gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp
+1
-1
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp
.../gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp
+1
-2
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp
...mm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp
+79
-0
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp
...mm/device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp
+80
-0
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp
...mm/device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp
+79
-0
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp
...mm/device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp
+80
-0
test/gemm_dlops/CMakeLists.txt
test/gemm_dlops/CMakeLists.txt
+9
-9
test/gemm_dlops/gemm_dlops_fp16.cpp
test/gemm_dlops/gemm_dlops_fp16.cpp
+130
-0
test/gemm_dlops/gemm_dlops_int8.cpp
test/gemm_dlops/gemm_dlops_int8.cpp
+131
-0
No files found.
include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp
View file @
39cfca6f
...
@@ -21,8 +21,8 @@ namespace device {
...
@@ -21,8 +21,8 @@ namespace device {
template
<
template
<
typename
ADataType
,
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
ALayout
,
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
CLayout
,
typename
CLayout
,
...
...
include/ck/utility/inner_product.hpp
View file @
39cfca6f
...
@@ -70,6 +70,12 @@ inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, f
...
@@ -70,6 +70,12 @@ inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, f
c
);
c
);
}
}
template
<
>
__device__
void
inner_product
<
half_t
,
half_t
,
float
>
(
const
half_t
&
a
,
const
half_t
&
b
,
float
&
c
)
{
c
+=
a
*
b
;
}
template
<
>
template
<
>
__device__
void
inner_product
<
half2_t
,
half2_t
,
float
>
(
const
half2_t
&
a
,
const
half2_t
&
b
,
float
&
c
)
__device__
void
inner_product
<
half2_t
,
half2_t
,
float
>
(
const
half2_t
&
a
,
const
half2_t
&
b
,
float
&
c
)
{
{
...
@@ -134,6 +140,36 @@ __device__ void inner_product<half8_t, half8_t, float>(const half8_t& a, const h
...
@@ -134,6 +140,36 @@ __device__ void inner_product<half8_t, half8_t, float>(const half8_t& a, const h
c
);
c
);
}
}
template
<
>
__device__
void
inner_product
<
int8_t
,
int8_t
,
int32_t
>
(
const
int8_t
&
a
,
const
int8_t
&
b
,
int32_t
&
c
)
{
c
+=
a
*
b
;
}
template
<
>
__device__
void
inner_product
<
int8x2_t
,
int8x2_t
,
int32_t
>
(
const
int8x2_t
&
a
,
const
int8x2_t
&
b
,
int32_t
&
c
)
{
// #if defined(CK_USE_DOT2_I32_I8)
// #if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
// asm volatile("\n \
// v_dot2_i32_i8 %0, %1, %2, %0\n \
// "
// : "=v"(c)
// : "v"(bit_cast<int32_t>(a)), "v"(bit_cast<int32_t>(b)), "0"(c));
// #else
// c = __builtin_amdgcn_sdot2(bit_cast<int32_t>(a), bit_cast<int32_t>(b), c, false);
// #endif
// #else
const
vector_type
<
int8_t
,
2
>
a_vector
{
a
};
const
vector_type
<
int8_t
,
2
>
b_vector
{
b
};
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
i
)
{
c
+=
type_convert
<
int32_t
>
(
a_vector
.
AsType
<
int8_t
>
()[
i
])
*
type_convert
<
int32_t
>
(
b_vector
.
AsType
<
int8_t
>
()[
i
]);
});
// #endif
}
template
<
>
template
<
>
__device__
void
__device__
void
inner_product
<
int8x4_t
,
int8x4_t
,
int32_t
>
(
const
int8x4_t
&
a
,
const
int8x4_t
&
b
,
int32_t
&
c
)
inner_product
<
int8x4_t
,
int8x4_t
,
int32_t
>
(
const
int8x4_t
&
a
,
const
int8x4_t
&
b
,
int32_t
&
c
)
...
...
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
View file @
39cfca6f
...
@@ -49,6 +49,14 @@ set(DEVICE_GEMM_DLOPS_INSTANCE_SOURCE
...
@@ -49,6 +49,14 @@ set(DEVICE_GEMM_DLOPS_INSTANCE_SOURCE
device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp;
device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp;
device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp;
device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp;
device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp;
device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp;
device_gemm_dlops_f16_f16_f16_mk_kn_mn_instance.cpp;
device_gemm_dlops_f16_f16_f16_mk_nk_mn_instance.cpp;
device_gemm_dlops_f16_f16_f16_km_kn_mn_instance.cpp;
device_gemm_dlops_f16_f16_f16_km_nk_mn_instance.cpp;
device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp;
device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp;
device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp;
device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp;
)
)
add_library
(
device_gemm_dlops_instance SHARED
${
DEVICE_GEMM_DLOPS_INSTANCE_SOURCE
}
)
add_library
(
device_gemm_dlops_instance SHARED
${
DEVICE_GEMM_DLOPS_INSTANCE_SOURCE
}
)
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_kn_mn_instance.cpp
0 → 100644
View file @
39cfca6f
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_dlops.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using
device_gemm_dlops_f16_f16_f16_km_kn_mn_instances
=
std
::
tuple
<
// clang-format off
// ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | |
// ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
/*
* K1 = 1
*/
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
8
,
1
,
8
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
2
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
8
,
1
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
2
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
8
,
1
,
4
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
1
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
8
,
2
>
,
S
<
8
,
1
,
1
,
1
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
8
,
1
,
1
,
1
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
1
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
4
,
2
>
,
S
<
8
,
1
,
1
,
1
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
8
,
1
,
1
,
1
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
128
,
8
,
1
,
8
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
4
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
64
,
8
,
1
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
2
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
64
,
128
,
8
,
1
,
4
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
2
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
4
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
/*
* K1 = 2
*/
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
8
,
2
,
8
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
2
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
8
,
2
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
2
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
8
,
2
,
4
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
2
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
8
,
2
>
,
S
<
8
,
1
,
1
,
2
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
8
,
1
,
1
,
2
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
2
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
4
,
2
>
,
S
<
8
,
1
,
1
,
2
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
8
,
1
,
1
,
2
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
128
,
8
,
2
,
8
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
64
,
8
,
2
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
2
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
64
,
128
,
8
,
2
,
4
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
2
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
// clang-format on
>
;
void
add_device_gemm_dlops_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_dlops_f16_f16_f16_km_kn_mn_instances
{});
}
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_km_nk_mn_instance.cpp
0 → 100644
View file @
39cfca6f
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_dlops.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using
device_gemm_dlops_f16_f16_f16_km_nk_mn_instances
=
std
::
tuple
<
// clang-format off
// ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
/*
* K1 = 1
*/
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
8
,
1
,
8
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
2
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
8
,
1
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
2
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
8
,
1
,
4
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
1
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
8
,
2
>
,
S
<
8
,
1
,
1
,
1
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
8
,
1
,
1
,
1
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
1
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
4
,
2
>
,
S
<
8
,
1
,
1
,
1
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
8
,
1
,
1
,
1
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
128
,
8
,
1
,
8
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
4
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
64
,
8
,
1
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
2
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
64
,
128
,
8
,
1
,
4
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
2
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
4
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
/*
* K1 = 2
*/
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
8
,
2
,
8
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
2
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
8
,
2
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
2
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
8
,
2
,
4
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
2
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
8
,
2
>
,
S
<
8
,
1
,
1
,
2
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
8
,
1
,
1
,
2
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
2
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
4
,
2
>
,
S
<
8
,
1
,
1
,
2
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
8
,
1
,
1
,
2
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
128
,
8
,
2
,
8
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
64
,
8
,
2
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
2
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
64
,
128
,
8
,
2
,
4
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
2
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
// clang-format on
>
;
void
add_device_gemm_dlops_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_dlops_f16_f16_f16_km_nk_mn_instances
{});
}
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_kn_mn_instance.cpp
0 → 100644
View file @
39cfca6f
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_dlops.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using
device_gemm_dlops_f16_f16_f16_mk_kn_mn_instances
=
std
::
tuple
<
// clang-format off
// ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
/*
* K1 = 1
*/
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
8
,
1
,
8
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
2
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
8
,
1
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
2
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
8
,
1
,
4
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
1
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
8
,
2
>
,
S
<
8
,
1
,
1
,
1
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
8
,
1
,
1
,
1
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
1
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
4
,
2
>
,
S
<
8
,
1
,
1
,
1
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
8
,
1
,
1
,
1
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
128
,
8
,
1
,
8
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
4
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
64
,
8
,
1
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
2
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
64
,
128
,
8
,
1
,
4
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
2
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
4
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
/*
* K1 = 2
*/
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
8
,
2
,
8
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
2
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
8
,
2
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
2
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
8
,
2
,
4
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
2
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
8
,
2
>
,
S
<
8
,
1
,
1
,
2
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
8
,
1
,
1
,
2
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
2
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
4
,
2
>
,
S
<
8
,
1
,
1
,
2
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
8
,
1
,
1
,
2
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
128
,
8
,
2
,
8
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
64
,
8
,
2
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
2
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
64
,
128
,
8
,
2
,
4
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
2
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
// clang-format on
>
;
void
add_device_gemm_dlops_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_dlops_f16_f16_f16_mk_kn_mn_instances
{});
}
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f16_f16_f16_mk_nk_mn_instance.cpp
0 → 100644
View file @
39cfca6f
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_dlops.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using
device_gemm_dlops_f16_f16_f16_mk_nk_mn_instances
=
std
::
tuple
<
// clang-format off
// ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
/*
* K1 = 1
*/
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
8
,
1
,
8
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
2
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
8
,
1
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
2
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
8
,
1
,
4
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
1
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
8
,
2
>
,
S
<
8
,
1
,
1
,
1
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
8
,
1
,
1
,
1
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
1
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
4
,
2
>
,
S
<
8
,
1
,
1
,
1
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
8
,
1
,
1
,
1
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
128
,
8
,
1
,
8
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
4
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
64
,
8
,
1
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
2
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
64
,
128
,
8
,
1
,
4
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
2
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
4
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
/*
* K1 = 2
*/
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
8
,
2
,
8
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
2
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
8
,
2
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
2
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
8
,
2
,
4
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
2
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
8
,
2
>
,
S
<
8
,
1
,
1
,
2
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
8
,
1
,
1
,
2
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
2
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
4
,
2
>
,
S
<
8
,
1
,
1
,
2
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
8
,
1
,
1
,
2
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
128
,
8
,
2
,
8
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
64
,
8
,
2
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
2
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
64
,
128
,
8
,
2
,
4
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
2
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
// clang-format on
>
;
void
add_device_gemm_dlops_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_dlops_f16_f16_f16_mk_nk_mn_instances
{});
}
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp
View file @
39cfca6f
...
@@ -24,11 +24,10 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
...
@@ -24,11 +24,10 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using
device_gemm_dlops_f32_f32_f32_km_kn_mn_instances
=
std
::
tuple
<
using
device_gemm_dlops_f32_f32_f32_km_kn_mn_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
// ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_
M
0_
M
1_K1| K0_
M
0_
M
1_K1| ArrangeOrder| Order| Lengths_K0_
M
0_
M
1_K1| ContiguousDimOrder| Lengths_K0_
M
0_
M
1_K1| Order| | |
// ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_
N
0_
N
1_K1| K0_
N
0_
N
1_K1| ArrangeOrder| Order| Lengths_K0_
N
0_
N
1_K1| ContiguousDimOrder| Lengths_K0_
N
0_
N
1_K1| Order| | |
// ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
/*
/*
* K1 = 1
* K1 = 1
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_nk_mn_instance.cpp
View file @
39cfca6f
...
@@ -27,7 +27,7 @@ using device_gemm_dlops_f32_f32_f32_km_nk_mn_instances = std::tuple<
...
@@ -27,7 +27,7 @@ using device_gemm_dlops_f32_f32_f32_km_nk_mn_instances = std::tuple<
// clang-format off
// clang-format off
// ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_
M
0_
M
1_K1| K0_
M
0_
M
1_K1| ArrangeOrder| Order| Lengths_K0_
M
0_
M
1_K1| ContiguousDimOrder| Lengths_K0_
M
0_
M
1_K1| Order| | |
// ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_
N
0_
N
1_K1| K0_
N
0_
N
1_K1| ArrangeOrder| Order| Lengths_K0_
N
0_
N
1_K1| ContiguousDimOrder| Lengths_K0_
N
0_
N
1_K1| Order| | |
// ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
/*
/*
* K1 = 1
* K1 = 1
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_kn_mn_instance.cpp
View file @
39cfca6f
...
@@ -27,7 +27,7 @@ using device_gemm_dlops_f32_f32_f32_mk_kn_mn_instances = std::tuple<
...
@@ -27,7 +27,7 @@ using device_gemm_dlops_f32_f32_f32_mk_kn_mn_instances = std::tuple<
// clang-format off
// clang-format off
// ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_
M
0_
M
1_K1| K0_
M
0_
M
1_K1| ArrangeOrder| Order| Lengths_K0_
M
0_
M
1_K1| ContiguousDimOrder| Lengths_K0_
M
0_
M
1_K1| Order| | |
// ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_
N
0_
N
1_K1| K0_
N
0_
N
1_K1| ArrangeOrder| Order| Lengths_K0_
N
0_
N
1_K1| ContiguousDimOrder| Lengths_K0_
N
0_
N
1_K1| Order| | |
// ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
/*
/*
* K1 = 1
* K1 = 1
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_mk_nk_mn_instance.cpp
View file @
39cfca6f
...
@@ -27,7 +27,7 @@ using device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances = std::tuple<
...
@@ -27,7 +27,7 @@ using device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances = std::tuple<
// clang-format off
// clang-format off
// ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_
M
0_
M
1_K1| K0_
M
0_
M
1_K1| ArrangeOrder| Order| Lengths_K0_
M
0_
M
1_K1| ContiguousDimOrder| Lengths_K0_
M
0_
M
1_K1| Order| | |
// ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_
N
0_
N
1_K1| K0_
N
0_
N
1_K1| ArrangeOrder| Order| Lengths_K0_
N
0_
N
1_K1| ContiguousDimOrder| Lengths_K0_
N
0_
N
1_K1| Order| | |
// ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
/*
/*
* K1 = 1
* K1 = 1
...
@@ -52,7 +52,6 @@ using device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances = std::tuple<
...
@@ -52,7 +52,6 @@ using device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances = std::tuple<
DeviceGemmDlops
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
128
,
8
,
2
,
8
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
128
,
8
,
2
,
8
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
64
,
8
,
2
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
2
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
64
,
8
,
2
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
2
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
64
,
128
,
8
,
2
,
4
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
2
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
DeviceGemmDlops
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
64
,
128
,
8
,
2
,
4
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
2
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
// DeviceGemmDlops< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 4, 1, S<8, 1>, S<8, 2>, S<4, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>
// clang-format on
// clang-format on
>
;
>
;
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_kn_mn_instance.cpp
0 → 100644
View file @
39cfca6f
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_dlops.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using
device_gemm_dlops_int8_int8_int8_km_kn_mn_instances
=
std
::
tuple
<
// clang-format off
// ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
/*
* K1 = 1
*/
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
8
,
1
,
8
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
2
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
8
,
1
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
2
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
8
,
1
,
4
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
1
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
8
,
2
>
,
S
<
8
,
1
,
1
,
1
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
8
,
1
,
1
,
1
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
1
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
4
,
2
>
,
S
<
8
,
1
,
1
,
1
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
8
,
1
,
1
,
1
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
128
,
8
,
1
,
8
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
4
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
64
,
8
,
1
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
2
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
64
,
128
,
8
,
1
,
4
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
2
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
4
,
1
,
4
,
1
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
1
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
/*
* K1 = 2
*/
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
8
,
2
,
8
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
2
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
8
,
2
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
2
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
8
,
2
,
4
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
2
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
8
,
2
>
,
S
<
8
,
1
,
1
,
2
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
8
,
1
,
1
,
2
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
2
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
4
,
2
>
,
S
<
8
,
1
,
1
,
2
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
8
,
1
,
1
,
2
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
128
,
8
,
2
,
8
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
64
,
8
,
2
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
2
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
64
,
128
,
8
,
2
,
4
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
2
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
4
,
2
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
/*
* K1 = 2
*/
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
8
,
4
,
8
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
2
,
4
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
8
,
4
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
4
,
1
,
2
,
4
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
8
,
4
,
4
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
4
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
8
,
2
>
,
S
<
8
,
1
,
1
,
4
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
8
,
1
,
1
,
4
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
4
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
4
,
2
>
,
S
<
8
,
1
,
1
,
4
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
8
,
1
,
1
,
4
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
128
,
8
,
4
,
8
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
4
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
4
,
1
,
4
,
4
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
64
,
8
,
4
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
4
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
4
,
1
,
2
,
4
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
64
,
128
,
8
,
4
,
4
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
2
,
4
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
4
,
1
,
4
,
4
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
// clang-format on
>
;
void
add_device_gemm_dlops_int8_int8_int8_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_dlops_int8_int8_int8_km_kn_mn_instances
{});
}
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_km_nk_mn_instance.cpp
0 → 100644
View file @
39cfca6f
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_dlops.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using
device_gemm_dlops_int8_int8_int8_km_nk_mn_instances
=
std
::
tuple
<
// clang-format off
// ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
/*
* K1 = 1
*/
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
//
/*
* K1 = 2
*/
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
/*
* K1 = 4
*/
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
8
,
4
,
8
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
2
,
4
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
8
,
4
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
4
,
1
,
2
,
4
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
8
,
4
,
4
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
4
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
8
,
2
>
,
S
<
8
,
1
,
1
,
4
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
8
,
1
,
1
,
4
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
4
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
4
,
2
>
,
S
<
8
,
1
,
1
,
4
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
8
,
1
,
1
,
4
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
128
,
8
,
4
,
8
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
4
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
4
,
1
,
4
,
4
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
64
,
8
,
4
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
4
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
4
,
1
,
2
,
4
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
64
,
128
,
8
,
4
,
4
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
2
,
4
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
4
,
1
,
4
,
4
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
// clang-format on
>
;
void
add_device_gemm_dlops_int8_int8_int8_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_dlops_int8_int8_int8_km_nk_mn_instances
{});
}
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_kn_mn_instance.cpp
0 → 100644
View file @
39cfca6f
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_dlops.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using
device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances
=
std
::
tuple
<
// clang-format off
// ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
/*
* K1 = 1
*/
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
//
/*
* K1 = 2
*/
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
/*
* K1 = 4
*/
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
8
,
4
,
8
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
2
,
4
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
8
,
4
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
4
,
1
,
2
,
4
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
8
,
4
,
4
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
4
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
8
,
2
>
,
S
<
8
,
1
,
1
,
4
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
8
,
1
,
1
,
4
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
4
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
4
,
2
>
,
S
<
8
,
1
,
1
,
4
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
8
,
1
,
1
,
4
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
128
,
8
,
4
,
8
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
4
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
4
,
1
,
4
,
4
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
64
,
8
,
4
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
4
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
4
,
1
,
2
,
4
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
64
,
128
,
8
,
4
,
4
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
2
,
4
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
4
,
1
,
4
,
4
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
// clang-format on
>
;
void
add_device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances
{});
}
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_int8_int8_int8_mk_nk_mn_instance.cpp
0 → 100644
View file @
39cfca6f
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_dlops.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using
device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances
=
std
::
tuple
<
// clang-format off
// ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
/*
* K1 = 1
*/
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 1, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 1, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 1, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 1, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 1, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 1, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<4, 1, 4, 1>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
//
/*
* K1 = 2
*/
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 8, 2, 8, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 8, 2, 4, 8, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 8, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 8, 2, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<1, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 128, 8, 2, 8, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 128, 64, 8, 2, 8, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// DeviceGemmDlops< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 128, 8, 2, 4, 8, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>,
/*
* K1 = 4
*/
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
8
,
4
,
8
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
2
,
4
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
8
,
4
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
4
,
1
,
2
,
4
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
8
,
4
,
4
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
4
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
8
,
2
>
,
S
<
8
,
1
,
1
,
4
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
8
,
1
,
1
,
4
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
8
,
4
,
4
,
8
,
1
,
S
<
8
,
2
>
,
S
<
4
,
2
>
,
S
<
8
,
1
,
1
,
4
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
8
,
1
,
1
,
4
>
,
S
<
1
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
128
,
8
,
4
,
8
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
4
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
4
,
1
,
4
,
4
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
128
,
64
,
8
,
4
,
8
,
4
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
4
,
4
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
4
,
1
,
2
,
4
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
,
DeviceGemmDlops
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
64
,
128
,
8
,
4
,
4
,
8
,
1
,
S
<
4
,
2
>
,
S
<
4
,
2
>
,
S
<
4
,
1
,
2
,
4
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
4
,
1
,
4
,
4
>
,
S
<
2
,
1
,
32
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
// clang-format on
>
;
void
add_device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances
{});
}
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
test/gemm_dlops/CMakeLists.txt
View file @
39cfca6f
...
@@ -2,14 +2,14 @@ add_test_executable(test_gemm_dlops_fp32 gemm_dlops_fp32.cpp)
...
@@ -2,14 +2,14 @@ add_test_executable(test_gemm_dlops_fp32 gemm_dlops_fp32.cpp)
target_link_libraries
(
test_gemm_dlops_fp32 PRIVATE host_tensor
)
target_link_libraries
(
test_gemm_dlops_fp32 PRIVATE host_tensor
)
target_link_libraries
(
test_gemm_dlops_fp32 PRIVATE device_gemm_dlops_instance
)
target_link_libraries
(
test_gemm_dlops_fp32 PRIVATE device_gemm_dlops_instance
)
#
add_test_executable(test_gemm_dlops_fp16 gemm_fp16.cpp)
add_test_executable
(
test_gemm_dlops_fp16 gemm_
dlops_
fp16.cpp
)
#
target_link_libraries(test_gemm_dlops_fp16 PRIVATE host_tensor)
target_link_libraries
(
test_gemm_dlops_fp16 PRIVATE host_tensor
)
#
target_link_libraries(test_gemm_dlops_fp16 PRIVATE device_gemm_dlops_instance)
target_link_libraries
(
test_gemm_dlops_fp16 PRIVATE device_gemm_dlops_instance
)
#
# add_test_executable(test_gemm_dlops_bf16 gemm_bf16.cpp)
# add_test_executable(test_gemm_dlops_bf16 gemm_
dlops_
bf16.cpp)
# target_link_libraries(test_gemm_dlops_bf16 PRIVATE host_tensor)
# target_link_libraries(test_gemm_dlops_bf16 PRIVATE host_tensor)
# target_link_libraries(test_gemm_dlops_bf16 PRIVATE device_gemm_dlops_instance)
# target_link_libraries(test_gemm_dlops_bf16 PRIVATE device_gemm_dlops_instance)
#
#
add_test_executable(test_gemm_dlops_int8 gemm_int8.cpp)
add_test_executable
(
test_gemm_dlops_int8 gemm_
dlops_
int8.cpp
)
#
target_link_libraries(test_gemm_dlops_int8 PRIVATE host_tensor)
target_link_libraries
(
test_gemm_dlops_int8 PRIVATE host_tensor
)
#
target_link_libraries(test_gemm_dlops_int8 PRIVATE device_gemm_dlops_instance)
target_link_libraries
(
test_gemm_dlops_int8 PRIVATE device_gemm_dlops_instance
)
test/gemm_dlops/gemm_dlops_fp16.cpp
0 → 100644
View file @
39cfca6f
#include <algorithm>
#include <cstdlib>
#include <half.hpp>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "../gemm/gemm_util.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_gemm_dlops.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DeviceGemmNoOpPtr
=
ck
::
tensor_operation
::
device
::
DeviceGemmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
void
add_device_gemm_dlops_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dlops_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dlops_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dlops_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
int
main
()
{
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
RowMajor
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
bool
res
=
true
;
std
::
vector
<
DeviceGemmNoOpPtr
>
gemmPtrs
;
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dlops_f16_f16_f16_km_kn_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
ColumnMajor
,
RowMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dlops_f16_f16_f16_km_nk_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dlops_f16_f16_f16_mk_kn_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
RowMajor
,
RowMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dlops_f16_f16_f16_mk_nk_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
RowMajor
,
ColumnMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
std
::
cout
<<
"TestGemm ..... "
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
res
?
0
:
1
;
}
test/gemm_dlops/gemm_dlops_int8.cpp
0 → 100644
View file @
39cfca6f
#include <algorithm>
#include <cstdlib>
#include <half.hpp>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "../gemm/gemm_util.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_gemm_dlops.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DeviceGemmNoOpPtr
=
ck
::
tensor_operation
::
device
::
DeviceGemmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
void
add_device_gemm_dlops_int8_int8_int8_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dlops_int8_int8_int8_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
int
main
()
{
using
ADataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
CDataType
=
int8_t
;
using
RowMajor
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
bool
res
=
true
;
std
::
vector
<
DeviceGemmNoOpPtr
>
gemmPtrs
;
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dlops_int8_int8_int8_km_kn_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
ColumnMajor
,
RowMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dlops_int8_int8_int8_km_nk_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dlops_int8_int8_int8_mk_kn_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
RowMajor
,
RowMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dlops_int8_int8_int8_mk_nk_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
RowMajor
,
ColumnMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
std
::
cout
<<
"TestGemm ..... "
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
res
?
0
:
1
;
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment