Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ce23d145
Commit
ce23d145
authored
Jul 18, 2023
by
danyao12
Browse files
Merge branch 'develop' into mha-train-develop
parents
1128cd3a
189ea3b9
Changes
32
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
56 additions
and
22 deletions
+56
-22
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp
...ce/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp
+2
-1
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp
...m/device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp
+2
-1
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
.../device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
+2
-1
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
.../device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
+2
-1
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
.../device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
+2
-1
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp
.../device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp
+2
-1
library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt
...tensor_operation_instance/gpu/quantization/CMakeLists.txt
+2
-0
profiler/src/profile_batched_gemm_multi_d.cpp
profiler/src/profile_batched_gemm_multi_d.cpp
+5
-1
profiler/src/profile_conv_bwd_data.cpp
profiler/src/profile_conv_bwd_data.cpp
+8
-0
profiler/src/profile_gemm.cpp
profiler/src/profile_gemm.cpp
+11
-3
test/batched_gemm_multi_d/test_batched_gemm_multi_d.cpp
test/batched_gemm_multi_d/test_batched_gemm_multi_d.cpp
+4
-2
test/gemm/CMakeLists.txt
test/gemm/CMakeLists.txt
+14
-10
No files found.
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp
View file @
ce23d145
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef __int8__
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
...
@@ -80,3 +80,4 @@ void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(
...
@@ -80,3 +80,4 @@ void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp
View file @
ce23d145
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef __int8__
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
...
@@ -78,3 +78,4 @@ void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(
...
@@ -78,3 +78,4 @@ void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
View file @
ce23d145
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef __int8__
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
...
@@ -66,3 +66,4 @@ void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
...
@@ -66,3 +66,4 @@ void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
View file @
ce23d145
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef __int8__
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
...
@@ -66,3 +66,4 @@ void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(
...
@@ -66,3 +66,4 @@ void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
View file @
ce23d145
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef __int8__
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
...
@@ -66,3 +66,4 @@ void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
...
@@ -66,3 +66,4 @@ void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp
View file @
ce23d145
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef __int8__
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
...
@@ -63,3 +63,4 @@ void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
...
@@ -63,3 +63,4 @@ void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt
View file @
ce23d145
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
set
(
CONV2D_PERLAYER_QUANT_SRC
set
(
CONV2D_PERLAYER_QUANT_SRC
conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp
conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp
conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp
conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp
...
@@ -36,3 +37,4 @@ add_instance_library(device_quantization_instance
...
@@ -36,3 +37,4 @@ add_instance_library(device_quantization_instance
${
CONV2D_BIAS_PERCHANNEL_QUANT_SRC
}
${
CONV2D_BIAS_PERCHANNEL_QUANT_SRC
}
${
GEMM_QUANT_SRC
}
${
GEMM_QUANT_SRC
}
)
)
endif
()
\ No newline at end of file
profiler/src/profile_batched_gemm_multi_d.cpp
View file @
ce23d145
...
@@ -70,8 +70,10 @@ int profile_batched_gemm_multi_d(int argc, char* argv[])
...
@@ -70,8 +70,10 @@ int profile_batched_gemm_multi_d(int argc, char* argv[])
const
int
BatchCount
=
std
::
stoi
(
argv
[
17
]);
const
int
BatchCount
=
std
::
stoi
(
argv
[
17
]);
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
#ifdef __int8__
using
INT8
=
int8_t
;
using
INT8
=
int8_t
;
#endif
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
@@ -163,6 +165,7 @@ int profile_batched_gemm_multi_d(int argc, char* argv[])
...
@@ -163,6 +165,7 @@ int profile_batched_gemm_multi_d(int argc, char* argv[])
{
{
return
profile
(
F16
{},
F16
{},
F16
{},
Col
{},
Col
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F16
{},
Col
{},
Col
{},
Row
{});
}
}
#ifdef __int8__
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
return
profile
(
INT8
{},
INT8
{},
INT8
{},
Row
{},
Row
{},
Row
{});
return
profile
(
INT8
{},
INT8
{},
INT8
{},
Row
{},
Row
{},
Row
{});
...
@@ -179,6 +182,7 @@ int profile_batched_gemm_multi_d(int argc, char* argv[])
...
@@ -179,6 +182,7 @@ int profile_batched_gemm_multi_d(int argc, char* argv[])
{
{
return
profile
(
INT8
{},
INT8
{},
INT8
{},
Col
{},
Col
{},
Row
{});
return
profile
(
INT8
{},
INT8
{},
INT8
{},
Col
{},
Col
{},
Row
{});
}
}
#endif
else
else
{
{
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
...
...
profiler/src/profile_conv_bwd_data.cpp
View file @
ce23d145
...
@@ -77,7 +77,9 @@ int profile_conv_bwd_data(int argc, char* argv[])
...
@@ -77,7 +77,9 @@ int profile_conv_bwd_data(int argc, char* argv[])
using
F32
=
float
;
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
BF16
=
ck
::
bhalf_t
;
#ifdef __int8__
using
INT8
=
int8_t
;
using
INT8
=
int8_t
;
#endif
using
NWC
=
ck
::
tensor_layout
::
convolution
::
NWC
;
using
NWC
=
ck
::
tensor_layout
::
convolution
::
NWC
;
using
NHWC
=
ck
::
tensor_layout
::
convolution
::
NHWC
;
using
NHWC
=
ck
::
tensor_layout
::
convolution
::
NHWC
;
...
@@ -138,10 +140,12 @@ int profile_conv_bwd_data(int argc, char* argv[])
...
@@ -138,10 +140,12 @@ int profile_conv_bwd_data(int argc, char* argv[])
{
{
return
profile
(
I1
,
NWC
{},
KXC
{},
NWK
{},
BF16
{},
BF16
{},
BF16
{});
return
profile
(
I1
,
NWC
{},
KXC
{},
NWK
{},
BF16
{},
BF16
{},
BF16
{});
}
}
#ifdef __int8__
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
)
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
)
{
{
return
profile
(
I1
,
NWC
{},
KXC
{},
NWK
{},
INT8
{},
INT8
{},
INT8
{});
return
profile
(
I1
,
NWC
{},
KXC
{},
NWK
{},
INT8
{},
INT8
{},
INT8
{});
}
}
#endif
}
}
else
if
(
num_dim_spatial
==
2
&&
layout
==
ConvLayout
::
NHWC_KYXC_NHWK
)
else
if
(
num_dim_spatial
==
2
&&
layout
==
ConvLayout
::
NHWC_KYXC_NHWK
)
{
{
...
@@ -157,10 +161,12 @@ int profile_conv_bwd_data(int argc, char* argv[])
...
@@ -157,10 +161,12 @@ int profile_conv_bwd_data(int argc, char* argv[])
{
{
return
profile
(
I2
,
NHWC
{},
KYXC
{},
NHWK
{},
BF16
{},
BF16
{},
BF16
{});
return
profile
(
I2
,
NHWC
{},
KYXC
{},
NHWK
{},
BF16
{},
BF16
{},
BF16
{});
}
}
#ifdef __int8__
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
)
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
)
{
{
return
profile
(
I2
,
NHWC
{},
KYXC
{},
NHWK
{},
INT8
{},
INT8
{},
INT8
{});
return
profile
(
I2
,
NHWC
{},
KYXC
{},
NHWK
{},
INT8
{},
INT8
{},
INT8
{});
}
}
#endif
}
}
else
if
(
num_dim_spatial
==
3
&&
layout
==
ConvLayout
::
NHWC_KYXC_NHWK
)
else
if
(
num_dim_spatial
==
3
&&
layout
==
ConvLayout
::
NHWC_KYXC_NHWK
)
{
{
...
@@ -176,10 +182,12 @@ int profile_conv_bwd_data(int argc, char* argv[])
...
@@ -176,10 +182,12 @@ int profile_conv_bwd_data(int argc, char* argv[])
{
{
return
profile
(
I3
,
NDHWC
{},
KZYXC
{},
NDHWK
{},
BF16
{},
BF16
{},
BF16
{});
return
profile
(
I3
,
NDHWC
{},
KZYXC
{},
NDHWK
{},
BF16
{},
BF16
{},
BF16
{});
}
}
#ifdef __int8__
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
)
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
)
{
{
return
profile
(
I3
,
NDHWC
{},
KZYXC
{},
NDHWK
{},
INT8
{},
INT8
{},
INT8
{});
return
profile
(
I3
,
NDHWC
{},
KZYXC
{},
NDHWK
{},
INT8
{},
INT8
{},
INT8
{});
}
}
#endif
}
}
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
...
...
profiler/src/profile_gemm.cpp
View file @
ce23d145
...
@@ -67,11 +67,15 @@ int profile_gemm(int argc, char* argv[])
...
@@ -67,11 +67,15 @@ int profile_gemm(int argc, char* argv[])
const
int
StrideB
=
std
::
stoi
(
argv
[
12
]);
const
int
StrideB
=
std
::
stoi
(
argv
[
12
]);
const
int
StrideC
=
std
::
stoi
(
argv
[
13
]);
const
int
StrideC
=
std
::
stoi
(
argv
[
13
]);
using
F32
=
float
;
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
#ifdef __bf16__
using
BF16
=
ck
::
bhalf_t
;
#endif
#ifdef __int8__
using
INT8
=
int8_t
;
using
INT8
=
int8_t
;
using
INT32
=
int32_t
;
using
INT32
=
int32_t
;
#endif
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
@@ -149,6 +153,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -149,6 +153,7 @@ int profile_gemm(int argc, char* argv[])
{
{
return
profile
(
Col
{},
Col
{},
Row
{},
F16
{},
F16
{},
F32
{},
F16
{});
return
profile
(
Col
{},
Col
{},
Row
{},
F16
{},
F16
{},
F32
{},
F16
{});
}
}
#ifdef __bf16__
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
return
profile
(
Row
{},
Row
{},
Row
{},
BF16
{},
BF16
{},
F32
{},
BF16
{});
return
profile
(
Row
{},
Row
{},
Row
{},
BF16
{},
BF16
{},
F32
{},
BF16
{});
...
@@ -165,6 +170,8 @@ int profile_gemm(int argc, char* argv[])
...
@@ -165,6 +170,8 @@ int profile_gemm(int argc, char* argv[])
{
{
return
profile
(
Col
{},
Col
{},
Row
{},
BF16
{},
BF16
{},
F32
{},
BF16
{});
return
profile
(
Col
{},
Col
{},
Row
{},
BF16
{},
BF16
{},
F32
{},
BF16
{});
}
}
#endif
#ifdef __int8__
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
return
profile
(
Row
{},
Row
{},
Row
{},
INT8
{},
INT8
{},
INT32
{},
INT8
{});
return
profile
(
Row
{},
Row
{},
Row
{},
INT8
{},
INT8
{},
INT32
{},
INT8
{});
...
@@ -181,6 +188,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -181,6 +188,7 @@ int profile_gemm(int argc, char* argv[])
{
{
return
profile
(
Col
{},
Col
{},
Row
{},
INT8
{},
INT8
{},
INT32
{},
INT8
{});
return
profile
(
Col
{},
Col
{},
Row
{},
INT8
{},
INT8
{},
INT32
{},
INT8
{});
}
}
#endif
else
else
{
{
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
...
...
test/batched_gemm_multi_d/test_batched_gemm_multi_d.cpp
View file @
ce23d145
...
@@ -68,7 +68,9 @@ using KernelTypes = ::testing::Types<std::tuple<Row, Row, Row>,
...
@@ -68,7 +68,9 @@ using KernelTypes = ::testing::Types<std::tuple<Row, Row, Row>,
}
// namespace
}
// namespace
TYPED_TEST_SUITE
(
TestBatchedGemmMultiD
,
KernelTypes
);
TYPED_TEST_SUITE
(
TestBatchedGemmMultiD
,
KernelTypes
);
#ifdef __fp16
TYPED_TEST
(
TestBatchedGemmMultiD
,
f16
)
{
this
->
template
Run
<
F16
>();
}
TYPED_TEST
(
TestBatchedGemmMultiD
,
f16
)
{
this
->
template
Run
<
F16
>();
}
#endif
#ifdef __int8__
TYPED_TEST
(
TestBatchedGemmMultiD
,
int8
)
{
this
->
template
Run
<
int8_t
>();
}
TYPED_TEST
(
TestBatchedGemmMultiD
,
int8
)
{
this
->
template
Run
<
int8_t
>();
}
#endif
test/gemm/CMakeLists.txt
View file @
ce23d145
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_test_executable
(
test_gemm_fp32 gemm_fp32.cpp
)
add_test_executable
(
test_gemm_fp32 gemm_fp32.cpp
)
target_link_libraries
(
test_gemm_fp32 PRIVATE utility
)
target_link_libraries
(
test_gemm_fp32 PRIVATE utility
)
target_link_libraries
(
test_gemm_fp32 PRIVATE device_gemm_instance
)
target_link_libraries
(
test_gemm_fp32 PRIVATE device_gemm_instance
)
endif
()
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_test_executable
(
test_gemm_fp16 gemm_fp16.cpp
)
add_test_executable
(
test_gemm_fp16 gemm_fp16.cpp
)
target_link_libraries
(
test_gemm_fp16 PRIVATE utility
)
target_link_libraries
(
test_gemm_fp16 PRIVATE utility
)
target_link_libraries
(
test_gemm_fp16 PRIVATE device_gemm_instance
)
target_link_libraries
(
test_gemm_fp16 PRIVATE device_gemm_instance
)
add_test_executable
(
test_gemm_bf16 gemm_bf16.cpp
)
target_link_libraries
(
test_gemm_bf16 PRIVATE utility
)
target_link_libraries
(
test_gemm_bf16 PRIVATE device_gemm_instance
)
add_test_executable
(
test_gemm_int8 gemm_int8.cpp
)
target_link_libraries
(
test_gemm_int8 PRIVATE utility
)
target_link_libraries
(
test_gemm_int8 PRIVATE device_gemm_instance
)
add_library
(
gemm_standalone_xdl_fp16_instances STATIC
add_library
(
gemm_standalone_xdl_fp16_instances STATIC
instance/gemm_f16_nn_instance.cpp
instance/gemm_f16_nn_instance.cpp
instance/gemm_f16_nt_instance.cpp
instance/gemm_f16_nt_instance.cpp
...
@@ -24,3 +17,14 @@ add_library(gemm_standalone_xdl_fp16_instances STATIC
...
@@ -24,3 +17,14 @@ add_library(gemm_standalone_xdl_fp16_instances STATIC
add_test_executable
(
test_gemm_standalone_xdl_fp16 gemm_standalone_xdl_fp16.cpp
)
add_test_executable
(
test_gemm_standalone_xdl_fp16 gemm_standalone_xdl_fp16.cpp
)
target_link_libraries
(
test_gemm_standalone_xdl_fp16 PRIVATE gemm_standalone_xdl_fp16_instances utility
)
target_link_libraries
(
test_gemm_standalone_xdl_fp16 PRIVATE gemm_standalone_xdl_fp16_instances utility
)
target_include_directories
(
test_gemm_standalone_xdl_fp16 PRIVATE instance/
)
target_include_directories
(
test_gemm_standalone_xdl_fp16 PRIVATE instance/
)
endif
()
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_test_executable
(
test_gemm_bf16 gemm_bf16.cpp
)
target_link_libraries
(
test_gemm_bf16 PRIVATE utility
)
target_link_libraries
(
test_gemm_bf16 PRIVATE device_gemm_instance
)
endif
()
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_test_executable
(
test_gemm_int8 gemm_int8.cpp
)
target_link_libraries
(
test_gemm_int8 PRIVATE utility
)
target_link_libraries
(
test_gemm_int8 PRIVATE device_gemm_instance
)
endif
()
\ No newline at end of file
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment