Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
6af2b468
Commit
6af2b468
authored
Aug 25, 2022
by
Adam Osewski
Browse files
CGEMM int4 example.
parent
6ceb900b
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
278 additions
and
85 deletions
+278
-85
example/22_cgemm/CMakeLists.txt
example/22_cgemm/CMakeLists.txt
+10
-4
example/22_cgemm/cgemm_xdl_bf16.cpp
example/22_cgemm/cgemm_xdl_bf16.cpp
+11
-11
example/22_cgemm/cgemm_xdl_common.hpp
example/22_cgemm/cgemm_xdl_common.hpp
+84
-37
example/22_cgemm/cgemm_xdl_fp16.cpp
example/22_cgemm/cgemm_xdl_fp16.cpp
+11
-11
example/22_cgemm/cgemm_xdl_fp32.cpp
example/22_cgemm/cgemm_xdl_fp32.cpp
+11
-11
example/22_cgemm/cgemm_xdl_int4.cpp
example/22_cgemm/cgemm_xdl_int4.cpp
+140
-0
example/22_cgemm/cgemm_xdl_int8.cpp
example/22_cgemm/cgemm_xdl_int8.cpp
+11
-11
No files found.
example/22_cgemm/CMakeLists.txt
View file @
6af2b468
...
...
@@ -5,7 +5,13 @@ add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp)
add_example_executable
(
example_cgemm_xdl_fp32 cgemm_xdl_fp32.cpp
)
add_example_executable
(
example_cgemm_xdl_int8 cgemm_xdl_int8.cpp
)
add_dependencies
(
example_cgemm_xdl example_cgemm_xdl_bf16
)
add_dependencies
(
example_cgemm_xdl example_cgemm_xdl_fp16
)
add_dependencies
(
example_cgemm_xdl example_cgemm_xdl_fp32
)
add_dependencies
(
example_cgemm_xdl example_cgemm_xdl_int8
)
add_dependencies
(
example_cgemm_xdl
example_cgemm_xdl_bf16
example_cgemm_xdl_fp16
example_cgemm_xdl_fp32
example_cgemm_xdl_int8
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_cgemm_xdl_int4 cgemm_xdl_int4.cpp
)
add_dependencies
(
example_cgemm_xdl example_cgemm_xdl_int4
)
endif
()
example/22_cgemm/cgemm_xdl_bf16.cpp
View file @
6af2b468
...
...
@@ -117,16 +117,16 @@ int main(int argc, char* argv[])
exit
(
0
);
}
return
run_cgemm_xdl
<
ADataType
,
BDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceCGemmInstance
,
ReferenceCGemmInstance
>
(
return
!
run_cgemm_xdl
<
ADataType
,
BDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceCGemmInstance
,
ReferenceCGemmInstance
>
(
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
do_verification
,
init_method
,
time_kernel
);
}
example/22_cgemm/cgemm_xdl_common.hpp
View file @
6af2b468
...
...
@@ -21,6 +21,9 @@ using F32 = float;
using
BF16
=
ck
::
bhalf_t
;
using
INT8
=
std
::
int8_t
;
using
INT32
=
std
::
int32_t
;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using
INT4
=
ck
::
int4_t
;
#endif
template
<
typename
ADataType
,
typename
BDataType
,
...
...
@@ -32,17 +35,31 @@ template <typename ADataType,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
DeviceCGemmInstance
,
typename
ReferenceCGemmInstance
>
int
run_cgemm_xdl
(
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
bool
do_verification
,
int
init_method
,
bool
time_kernel
)
typename
ReferenceCGemmInstance
,
typename
KernelADataType
=
ADataType
,
typename
KernelBDataType
=
BDataType
,
typename
KernelCDataType
=
CDataType
>
bool
run_cgemm_xdl
(
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
bool
do_verification
,
int
init_method
,
bool
time_kernel
)
{
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static_assert
(
sizeof
(
ck
::
int4_t
)
==
sizeof
(
int8_t
),
"sizeof ck::int4_t and int8_t is different!"
);
static_assert
(
sizeof
(
ADataType
)
==
sizeof
(
KernelADataType
),
"sizeof ADataType and KernelADataType is different!"
);
static_assert
(
sizeof
(
BDataType
)
==
sizeof
(
KernelBDataType
),
"sizeof BDataType and KernelBDataType is different!"
);
static_assert
(
sizeof
(
CDataType
)
==
sizeof
(
KernelCDataType
),
"sizeof CDataType and KernelCDataType is different!"
);
#endif
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
std
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
...
...
@@ -61,8 +78,10 @@ int run_cgemm_xdl(ck::index_t M,
Tensor
<
ADataType
>
a_m_k_imag
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
BDataType
>
b_k_n_real
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
BDataType
>
b_k_n_imag
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
CDataType
>
c_m_n_real_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_imag_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
KernelCDataType
>
c_m_n_real_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
KernelCDataType
>
c_m_n_imag_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
std
::
cout
<<
"a_m_k_real: "
<<
a_m_k_real
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"a_m_k_imag: "
<<
a_m_k_imag
.
mDesc
<<
std
::
endl
;
...
...
@@ -89,16 +108,35 @@ int run_cgemm_xdl(ck::index_t M,
auto
cgemm
=
DeviceCGemmInstance
{};
DeviceMem
a_m_k_real_device_buf
(
sizeof
(
ADataType
)
*
a_m_k_real
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
a_m_k_imag_device_buf
(
sizeof
(
ADataType
)
*
a_m_k_imag
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_k_n_real_device_buf
(
sizeof
(
BDataType
)
*
b_k_n_real
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_k_n_imag_device_buf
(
sizeof
(
BDataType
)
*
b_k_n_imag
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_real_device_buf
(
sizeof
(
CDataType
)
*
DeviceMem
a_m_k_real_device_buf
(
sizeof
(
KernelADataType
)
*
a_m_k_real
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
a_m_k_imag_device_buf
(
sizeof
(
KernelADataType
)
*
a_m_k_imag
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_k_n_real_device_buf
(
sizeof
(
KernelBDataType
)
*
b_k_n_real
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_k_n_imag_device_buf
(
sizeof
(
KernelBDataType
)
*
b_k_n_imag
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_real_device_buf
(
sizeof
(
KernelCDataType
)
*
c_m_n_real_device_result
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_imag_device_buf
(
sizeof
(
CDataType
)
*
DeviceMem
c_m_n_imag_device_buf
(
sizeof
(
Kernel
CDataType
)
*
c_m_n_imag_device_result
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
workspace_device_buf
(
cgemm
.
GetWorkspaceSize
(
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
));
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
if
constexpr
(
std
::
is_same_v
<
ADataType
,
ck
::
int4_t
>
)
{
Tensor
<
KernelADataType
>
a_m_k_real_converted
(
a_m_k_real
);
Tensor
<
KernelADataType
>
a_m_k_imag_converted
(
a_m_k_imag
);
Tensor
<
KernelBDataType
>
b_k_n_real_converted
(
b_k_n_real
);
Tensor
<
KernelBDataType
>
b_k_n_imag_converted
(
b_k_n_imag
);
a_m_k_real
=
a_m_k_real_converted
;
a_m_k_imag
=
a_m_k_imag_converted
;
b_k_n_real
=
b_k_n_real_converted
;
b_k_n_imag
=
b_k_n_imag_converted
;
}
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
a_m_k_real_device_buf
.
ToDevice
(
a_m_k_real
.
mData
.
data
());
a_m_k_imag_device_buf
.
ToDevice
(
a_m_k_imag
.
mData
.
data
());
b_k_n_real_device_buf
.
ToDevice
(
b_k_n_real
.
mData
.
data
());
...
...
@@ -111,13 +149,13 @@ int run_cgemm_xdl(ck::index_t M,
// do GEMM
auto
invoker
=
cgemm
.
MakeInvoker
();
auto
argument
=
cgemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_m_k_real_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ADataType
*>
(
a_m_k_imag_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_k_n_real_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_k_n_imag_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_real_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_imag_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
workspace_device_buf
.
GetDeviceBuffer
()),
cgemm
.
MakeArgument
(
static_cast
<
Kernel
ADataType
*>
(
a_m_k_real_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Kernel
ADataType
*>
(
a_m_k_imag_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Kernel
BDataType
*>
(
b_k_n_real_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Kernel
BDataType
*>
(
b_k_n_imag_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Kernel
CDataType
*>
(
c_m_n_real_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Kernel
CDataType
*>
(
c_m_n_imag_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Kernel
CDataType
*>
(
workspace_device_buf
.
GetDeviceBuffer
()),
M
,
N
,
K
,
...
...
@@ -142,16 +180,12 @@ int run_cgemm_xdl(ck::index_t M,
std
::
size_t
(
2
)
*
(
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
CDataType
)
*
M
*
N
);
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
cgemm
.
GetTypeString
()
<<
std
::
endl
;
c_m_n_real_device_buf
.
FromDevice
(
c_m_n_real_device_result
.
mData
.
data
());
c_m_n_imag_device_buf
.
FromDevice
(
c_m_n_imag_device_result
.
mData
.
data
());
if
(
do_verification
)
{
Tensor
<
CDataType
>
c_m_n_real_host_result
(
...
...
@@ -159,9 +193,8 @@ int run_cgemm_xdl(ck::index_t M,
Tensor
<
CDataType
>
c_m_n_imag_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
auto
ref_cgemm
=
ReferenceCGemmInstance
{};
auto
ref_invoker
=
ref_cgemm
.
MakeInvoker
();
auto
ref_cgemm
=
ReferenceCGemmInstance
{};
auto
ref_invoker
=
ref_cgemm
.
MakeInvoker
();
auto
ref_argument
=
ref_cgemm
.
MakeArgument
(
a_m_k_real
,
a_m_k_imag
,
b_k_n_real
,
...
...
@@ -174,19 +207,33 @@ int run_cgemm_xdl(ck::index_t M,
ref_invoker
.
Run
(
ref_argument
);
c_m_n_real_device_buf
.
FromDevice
(
c_m_n_real_device_result
.
mData
.
data
());
c_m_n_imag_device_buf
.
FromDevice
(
c_m_n_imag_device_result
.
mData
.
data
());
bool
result
=
true
;
result
=
ck
::
utils
::
check_err
(
c_m_n_real_device_result
.
mData
,
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
if
constexpr
(
std
::
is_same_v
<
ADataType
,
ck
::
int4_t
>
)
{
const
Tensor
<
CDataType
>
c_m_n_real_device_result_converted
(
c_m_n_real_device_result
);
const
Tensor
<
CDataType
>
c_m_n_imag_device_result_converted
(
c_m_n_imag_device_result
);
c_m_n_real_device_result
=
c_m_n_real_device_result_converted
;
c_m_n_imag_device_result
=
c_m_n_imag_device_result_converted
;
}
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
result
=
ck
::
utils
::
check_err
(
c_m_n_real_device_result
.
mData
,
c_m_n_real_host_result
.
mData
,
"Verification error: incorrect results in real part!"
,
1e-2
f
,
1e-1
f
);
result
=
result
&&
result
=
result
&&
ck
::
utils
::
check_err
(
c_m_n_imag_device_result
.
mData
,
c_m_n_imag_host_result
.
mData
,
"Verification error: incorrect results in imaginary part!"
,
1e-2
f
,
1e-1
f
);
return
result
?
0
:
1
;
return
result
;
}
return
0
;
return
true
;
}
example/22_cgemm/cgemm_xdl_fp16.cpp
View file @
6af2b468
...
...
@@ -116,16 +116,16 @@ int main(int argc, char* argv[])
exit
(
0
);
}
return
run_cgemm_xdl
<
ADataType
,
BDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceCGemmInstance
,
ReferenceCGemmInstance
>
(
return
!
run_cgemm_xdl
<
ADataType
,
BDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceCGemmInstance
,
ReferenceCGemmInstance
>
(
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
do_verification
,
init_method
,
time_kernel
);
}
example/22_cgemm/cgemm_xdl_fp32.cpp
View file @
6af2b468
...
...
@@ -117,16 +117,16 @@ int main(int argc, char* argv[])
exit
(
0
);
}
return
run_cgemm_xdl
<
ADataType
,
BDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceCGemmInstance
,
ReferenceCGemmInstance
>
(
return
!
run_cgemm_xdl
<
ADataType
,
BDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceCGemmInstance
,
ReferenceCGemmInstance
>
(
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
do_verification
,
init_method
,
time_kernel
);
}
example/22_cgemm/cgemm_xdl_int4.cpp
0 → 100644
View file @
6af2b468
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "cgemm_xdl_common.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
// Element types used for the host-side tensors and for the CPU reference CGEMM.
using ADataType = INT4;
using BDataType = INT4;
using CDataType = INT4;

// Element types passed to the device instance and to run_cgemm_xdl as the
// Kernel{A,B,C}DataType template arguments.
using KernelADataType = INT8;
using KernelBDataType = INT8;
using KernelCDataType = INT8;

// Accumulator and C-shuffle element types of the device instance.
using AccDataType      = INT32;
using CShuffleDataType = INT32;

// Memory layouts of the A, B and C matrices.
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;

// Element-wise operation used for A, B and C alike.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// GEMM specialization selected for the device instance.
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

// Host reference implementation, instantiated on the INT4 aliases above.
using ReferenceCGemmInstance = ck::tensor_operation::host::
    ReferenceCGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
// clang-format off
// Device CGEMM instance used by this example. It is instantiated on the
// Kernel*DataType aliases (INT8) rather than on the INT4 aliases.
using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_CShuffle<
    // --- layouts -----------------------------------------------------------
    ALayout,             // typename ALayout
    BLayout,             // typename BLayout
    CLayout,             // typename CLayout
    // --- data types --------------------------------------------------------
    KernelADataType,     // typename ADataType
    KernelBDataType,     // typename BDataType
    KernelCDataType,     // typename CDataType
    AccDataType,         // typename GemmAccDataType
    CShuffleDataType,    // typename CShuffleDataType
    // --- element-wise operations -------------------------------------------
    PassThrough,         // typename AElementwiseOperation
    PassThrough,         // typename BElementwiseOperation
    PassThrough,         // typename CElementwiseOperation
    GemmDefault,         // GemmSpecialization GemmSpec
    // --- block / tile configuration ----------------------------------------
    1,                   // index_t NumGemmKPrefetchStage
    256,                 // index_t BlockSize
    256,                 // index_t MPerBlock
    128,                 // index_t NPerBlock
    64,                  // index_t KPerBlock
    16,                  // index_t AK1
    16,                  // index_t BK1
    32,                  // index_t MPerXDL
    32,                  // index_t NPerXDL
    4,                   // index_t MXdlPerWave
    2,                   // index_t NXdlPerWave
    // --- A block transfer --------------------------------------------------
    S<4, 64, 1>,         // typename ABlockTransferThreadClusterLengths_AK0_M_AK1
    S<1, 0, 2>,          // typename ABlockTransferThreadClusterArrangeOrder
    S<1, 0, 2>,          // typename ABlockTransferSrcAccessOrder
    2,                   // index_t ABlockTransferSrcVectorDim
    16,                  // index_t ABlockTransferSrcScalarPerVector
    16,                  // index_t ABlockTransferDstScalarPerVector_AK1
    1,                   // index_t ABlockLdsExtraM
    // --- B block transfer --------------------------------------------------
    S<4, 64, 1>,         // typename BBlockTransferThreadClusterLengths_BK0_N_BK1
    S<1, 0, 2>,          // typename BBlockTransferThreadClusterArrangeOrder
    S<1, 0, 2>,          // typename BBlockTransferSrcAccessOrder
    2,                   // index_t BBlockTransferSrcVectorDim
    8,                   // index_t BBlockTransferSrcScalarPerVector
    8,                   // index_t BBlockTransferDstScalarPerVector_BK1
    1,                   // index_t BBlockLdsExtraN
    // --- C shuffle ---------------------------------------------------------
    1,                   // index_t CShuffleMXdlPerWavePerShuffle
    1,                   // index_t CShuffleNXdlPerWavePerShuffle
    S<1, 64, 1, 4>,      // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
    16>;                 // index_t CShuffleBlockTransferScalarPerVector_NPerBlock
// clang-format on
// Entry point of the int4 CGEMM example.
//
// Accepted command lines:
//   * 3 arguments: verification flag, initialization method, time-kernel flag
//   * 9 arguments: the three above, then M, N, K, StrideA, StrideB, StrideC
// Any other argument count prints the usage text and exits with EXIT_SUCCESS.
//
// The process exit code is 0 when run_cgemm_xdl returns true and 1 otherwise.
int main(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = true;

    // CGEMM shape
    ck::index_t M = 1024;
    ck::index_t N = 1152;
    ck::index_t K = 512;

    ck::index_t StrideA = K;
    ck::index_t StrideB = K;
    ck::index_t StrideC = N;

    // The first three arguments mean the same thing in both accepted forms.
    const auto read_flags = [&] {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
    };

    // NOTE(review): argc == 1 (no arguments) lands in the default branch, so the
    // program prints the usage text and exits without running the example --
    // confirm this is intended.
    switch(argc)
    {
    case 4: read_flags(); break;

    case 10:
        read_flags();

        M = std::stoi(argv[4]);
        N = std::stoi(argv[5]);
        K = std::stoi(argv[6]);

        StrideA = std::stoi(argv[7]);
        StrideB = std::stoi(argv[8]);
        StrideC = std::stoi(argv[9]);
        break;

    default:
        std::cout << "arg1: verification (0=no, 1=yes)\n"
                  << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
                  << "arg3: time kernel (0=no, 1=yes)\n"
                  << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"
                  << std::endl;
        exit(EXIT_SUCCESS);
    }

    // run_cgemm_xdl returns a bool; map true to exit code 0 and false to 1.
    const bool passed = run_cgemm_xdl<ADataType,
                                      BDataType,
                                      CDataType,
                                      ALayout,
                                      BLayout,
                                      CLayout,
                                      PassThrough,
                                      PassThrough,
                                      PassThrough,
                                      DeviceCGemmInstance,
                                      ReferenceCGemmInstance,
                                      KernelADataType,
                                      KernelBDataType,
                                      KernelCDataType>(
        M, N, K, StrideA, StrideB, StrideC, do_verification, init_method, time_kernel);

    return passed ? 0 : 1;
}
example/22_cgemm/cgemm_xdl_int8.cpp
View file @
6af2b468
...
...
@@ -117,16 +117,16 @@ int main(int argc, char* argv[])
exit
(
0
);
}
return
run_cgemm_xdl
<
ADataType
,
BDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceCGemmInstance
,
ReferenceCGemmInstance
>
(
return
!
run_cgemm_xdl
<
ADataType
,
BDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceCGemmInstance
,
ReferenceCGemmInstance
>
(
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
do_verification
,
init_method
,
time_kernel
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment