Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f6138c40
Commit
f6138c40
authored
Feb 26, 2022
by
rocking
Browse files
Add int8 of mk_nk_mn to the ckProfiler
parent
e221d11e
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
107 additions
and
16 deletions
+107
-16
device_operation/CMakeLists.txt
device_operation/CMakeLists.txt
+13
-12
device_operation/src/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp
...e_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp
+56
-0
profiler/include/profile_gemm_impl.hpp
profiler/include/profile_gemm_impl.hpp
+13
-0
profiler/src/profile_gemm.cpp
profiler/src/profile_gemm.cpp
+24
-3
profiler/src/profile_gemm_bias_2d.cpp
profiler/src/profile_gemm_bias_2d.cpp
+1
-1
No files found.
device_operation/CMakeLists.txt
View file @
f6138c40
...
@@ -22,6 +22,7 @@ set(DEVICE_GEMM_INSTANCE_SOURCE
...
@@ -22,6 +22,7 @@ set(DEVICE_GEMM_INSTANCE_SOURCE
${
PROJECT_SOURCE_DIR
}
/device_operation/src/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/src/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/src/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/src/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/src/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/src/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/src/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp;
...
@@ -82,9 +83,9 @@ set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE
...
@@ -82,9 +83,9 @@ set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE
)
)
# device_conv1d_fwd_instance
# device_conv1d_fwd_instance
set
(
DEVICE_CONV1D_FWD_INSTANCE_SOURCE
set
(
DEVICE_CONV1D_FWD_INSTANCE_SOURCE
${
PROJECT_SOURCE_DIR
}
/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp;
)
)
# device_conv2d_fwd_bias_relu_instance
# device_conv2d_fwd_bias_relu_instance
set
(
DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE
set
(
DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE
...
@@ -106,11 +107,11 @@ add_library(device_gemm_bias_2d_instance SHARED ${DEVICE_GEMM_BIAS_2D_INSTANCE_S
...
@@ -106,11 +107,11 @@ add_library(device_gemm_bias_2d_instance SHARED ${DEVICE_GEMM_BIAS_2D_INSTANCE_S
add_library
(
device_gemm_bias_relu_instance SHARED
${
DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE
}
)
add_library
(
device_gemm_bias_relu_instance SHARED
${
DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE
}
)
add_library
(
device_gemm_bias_relu_add_instance SHARED
${
DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE
}
)
add_library
(
device_gemm_bias_relu_add_instance SHARED
${
DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE
}
)
add_library
(
device_batched_gemm_instance SHARED
${
DEVICE_BATCHED_GEMM_INSTANCE_SOURCE
}
)
add_library
(
device_batched_gemm_instance SHARED
${
DEVICE_BATCHED_GEMM_INSTANCE_SOURCE
}
)
add_library
(
device_conv1d_fwd_instance SHARED
${
DEVICE_CONV1D_FWD_INSTANCE_SOURCE
}
)
add_library
(
device_conv1d_fwd_instance SHARED
${
DEVICE_CONV1D_FWD_INSTANCE_SOURCE
}
)
add_library
(
device_conv2d_fwd_instance SHARED
${
DEVICE_CONV2D_FWD_INSTANCE_SOURCE
}
)
add_library
(
device_conv2d_fwd_instance SHARED
${
DEVICE_CONV2D_FWD_INSTANCE_SOURCE
}
)
add_library
(
device_conv2d_fwd_bias_relu_instance SHARED
${
DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE
}
)
add_library
(
device_conv2d_fwd_bias_relu_instance SHARED
${
DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE
}
)
add_library
(
device_conv2d_fwd_bias_relu_add_instance SHARED
${
DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE
}
)
add_library
(
device_conv2d_fwd_bias_relu_add_instance SHARED
${
DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE
}
)
add_library
(
device_conv2d_fwd_bias_relu_atomic_add_instance SHARED
${
DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE
}
)
add_library
(
device_conv2d_fwd_bias_relu_atomic_add_instance SHARED
${
DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE
}
)
target_include_directories
(
device_gemm_instance SYSTEM PUBLIC $<BUILD_INTERFACE:
${
HALF_INCLUDE_DIR
}
>
)
target_include_directories
(
device_gemm_instance SYSTEM PUBLIC $<BUILD_INTERFACE:
${
HALF_INCLUDE_DIR
}
>
)
target_include_directories
(
device_gemm_bias_2d_instance SYSTEM PUBLIC $<BUILD_INTERFACE:
${
HALF_INCLUDE_DIR
}
>
)
target_include_directories
(
device_gemm_bias_2d_instance SYSTEM PUBLIC $<BUILD_INTERFACE:
${
HALF_INCLUDE_DIR
}
>
)
...
@@ -150,8 +151,8 @@ install(TARGETS device_gemm_bias_2d_instance LIBRARY DESTINATION lib)
...
@@ -150,8 +151,8 @@ install(TARGETS device_gemm_bias_2d_instance LIBRARY DESTINATION lib)
install
(
TARGETS device_gemm_bias_relu_instance LIBRARY DESTINATION lib
)
install
(
TARGETS device_gemm_bias_relu_instance LIBRARY DESTINATION lib
)
install
(
TARGETS device_gemm_bias_relu_add_instance LIBRARY DESTINATION lib
)
install
(
TARGETS device_gemm_bias_relu_add_instance LIBRARY DESTINATION lib
)
install
(
TARGETS device_batched_gemm_instance LIBRARY DESTINATION lib
)
install
(
TARGETS device_batched_gemm_instance LIBRARY DESTINATION lib
)
install
(
TARGETS device_conv1d_fwd_instance LIBRARY DESTINATION lib
)
install
(
TARGETS device_conv1d_fwd_instance LIBRARY DESTINATION lib
)
install
(
TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib
)
install
(
TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib
)
install
(
TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib
)
install
(
TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib
)
install
(
TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib
)
install
(
TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib
)
install
(
TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib
)
install
(
TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib
)
device_operation/src/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp
0 → 100644
View file @
f6138c40
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using
device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances
=
std
::
tuple
<
// clang-format off
//#####################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl_C_Shuffle
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceGemmXdl_C_Shuffle
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceGemmXdl_C_Shuffle
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
>
,
DeviceGemmXdl_C_Shuffle
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceGemmXdl_C_Shuffle
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
4
>
,
8
>
,
DeviceGemmXdl_C_Shuffle
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
>
,
DeviceGemmXdl_C_Shuffle
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
>
,
DeviceGemmXdl_C_Shuffle
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceGemmXdl_C_Shuffle
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
,
DeviceGemmXdl_C_Shuffle
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
4
>
,
8
>
,
DeviceGemmXdl_C_Shuffle
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
>
,
DeviceGemmXdl_C_Shuffle
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
>
,
DeviceGemmXdl_C_Shuffle
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
>
// clang-format on
>
;
void
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances
{});
}
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
profiler/include/profile_gemm_impl.hpp
View file @
f6138c40
...
@@ -31,6 +31,8 @@ void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector<De
...
@@ -31,6 +31,8 @@ void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector<De
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances
(
void
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
...
@@ -290,6 +292,17 @@ void profile_gemm_impl(int do_verification,
...
@@ -290,6 +292,17 @@ void profile_gemm_impl(int do_verification,
}
}
}
}
}
}
else
if
constexpr
(
is_same
<
ADataType
,
int8_t
>::
value
&&
is_same
<
BDataType
,
int8_t
>::
value
&&
is_same
<
CDataType
,
int8_t
>::
value
)
{
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances
(
gemm_ptrs
);
}
}
if
(
gemm_ptrs
.
size
()
<=
0
)
if
(
gemm_ptrs
.
size
()
<=
0
)
{
{
...
...
profiler/src/profile_gemm.cpp
View file @
f6138c40
...
@@ -20,8 +20,9 @@ enum GemmMatrixLayout
...
@@ -20,8 +20,9 @@ enum GemmMatrixLayout
enum
GemmDataType
enum
GemmDataType
{
{
F32_F32_F32
,
// 0
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
F16_F16_F16
,
// 1
INT8_INT8_INT8
,
// 2
};
};
int
profile_gemm
(
int
argc
,
char
*
argv
[])
int
profile_gemm
(
int
argc
,
char
*
argv
[])
...
@@ -29,7 +30,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -29,7 +30,7 @@ int profile_gemm(int argc, char* argv[])
if
(
!
(
argc
==
14
||
argc
==
15
))
if
(
!
(
argc
==
14
||
argc
==
15
))
{
{
printf
(
"arg1: tensor operation (gemm: GEMM)
\n
"
);
printf
(
"arg1: tensor operation (gemm: GEMM)
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16)
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16
; 2: int8
)
\n
"
);
printf
(
"arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];
\n
"
);
printf
(
"arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];
\n
"
);
printf
(
" 1: A[m, k] * B[n, k] = C[m, n];
\n
"
);
printf
(
" 1: A[m, k] * B[n, k] = C[m, n];
\n
"
);
printf
(
" 2: A[k, m] * B[k, n] = C[m, n];
\n
"
);
printf
(
" 2: A[k, m] * B[k, n] = C[m, n];
\n
"
);
...
@@ -221,6 +222,26 @@ int profile_gemm(int argc, char* argv[])
...
@@ -221,6 +222,26 @@ int profile_gemm(int argc, char* argv[])
(
StrideC
<
0
)
?
N
:
StrideC
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
KBatch
);
}
}
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
}
else
else
{
{
throw
std
::
runtime_error
(
"wrong! this GEMM data_type & layout is not implemented"
);
throw
std
::
runtime_error
(
"wrong! this GEMM data_type & layout is not implemented"
);
...
...
profiler/src/profile_gemm_bias_2d.cpp
View file @
f6138c40
...
@@ -28,7 +28,7 @@ int profile_gemm_bias_2d(int argc, char* argv[])
...
@@ -28,7 +28,7 @@ int profile_gemm_bias_2d(int argc, char* argv[])
{
{
if
(
!
(
argc
==
16
||
argc
==
17
))
if
(
!
(
argc
==
16
||
argc
==
17
))
{
{
printf
(
"arg1: tensor operation (gemm: GEMM+Bias)
\n
"
);
printf
(
"arg1: tensor operation (gemm: GEMM+Bias
_2d
)
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16)
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16)
\n
"
);
printf
(
"arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];
\n
"
);
printf
(
"arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];
\n
"
);
printf
(
" 1: A[m, k] * B[n, k] = C[m, n];
\n
"
);
printf
(
" 1: A[m, k] * B[n, k] = C[m, n];
\n
"
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment