Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
6c8ca54b
Commit
6c8ca54b
authored
Apr 20, 2022
by
ltqin
Browse files
gemm reference add double data type
parent
bf5af9f9
Changes
19
Show whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
116 additions
and
45 deletions
+116
-45
example/01_gemm/gemm_xdl_bf16.cpp
example/01_gemm/gemm_xdl_bf16.cpp
+1
-1
example/01_gemm/gemm_xdl_fp16.cpp
example/01_gemm/gemm_xdl_fp16.cpp
+1
-1
example/01_gemm/gemm_xdl_fp64.cpp
example/01_gemm/gemm_xdl_fp64.cpp
+1
-1
example/01_gemm/gemm_xdl_int8.cpp
example/01_gemm/gemm_xdl_int8.cpp
+7
-2
example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp
example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp
+7
-2
example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp
...quant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp
+7
-2
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
+1
-1
example/16_gemm_reduce/gemm_reduce_xdl_fp16.cpp
example/16_gemm_reduce/gemm_reduce_xdl_fp16.cpp
+1
-1
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+12
-11
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
...library/reference_tensor_operation/cpu/reference_gemm.hpp
+5
-4
profiler/include/profile_gemm_impl.hpp
profiler/include/profile_gemm_impl.hpp
+10
-2
profiler/include/profile_gemm_reduce_impl.hpp
profiler/include/profile_gemm_reduce_impl.hpp
+7
-2
profiler/include/profile_grouped_gemm_impl.hpp
profiler/include/profile_grouped_gemm_impl.hpp
+2
-0
profiler/src/profile_gemm.cpp
profiler/src/profile_gemm.cpp
+16
-0
test/gemm/gemm_fp16.cpp
test/gemm/gemm_fp16.cpp
+8
-3
test/gemm/gemm_fp32.cpp
test/gemm/gemm_fp32.cpp
+10
-5
test/gemm/gemm_int8.cpp
test/gemm/gemm_int8.cpp
+10
-5
test/gemm/gemm_util.hpp
test/gemm/gemm_util.hpp
+3
-0
test/grouped_gemm/grouped_gemm_fp16.cpp
test/grouped_gemm/grouped_gemm_fp16.cpp
+7
-2
No files found.
example/01_gemm/gemm_xdl_bf16.cpp
View file @
6c8ca54b
...
@@ -81,7 +81,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
...
@@ -81,7 +81,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
float
,
float
,
float
,
PassThrough
,
PassThrough
,
PassThrough
>
;
ReferenceGemm
<
float
,
float
,
float
,
float
,
PassThrough
,
PassThrough
,
PassThrough
>
;
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
...
...
example/01_gemm/gemm_xdl_fp16.cpp
View file @
6c8ca54b
...
@@ -54,7 +54,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
...
@@ -54,7 +54,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
...
...
example/01_gemm/gemm_xdl_fp64.cpp
View file @
6c8ca54b
...
@@ -54,7 +54,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
...
@@ -54,7 +54,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
...
...
example/01_gemm/gemm_xdl_int8.cpp
View file @
6c8ca54b
...
@@ -80,8 +80,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
...
@@ -80,8 +80,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
4
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
4
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
BDataType
,
CDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
...
...
example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp
View file @
6c8ca54b
...
@@ -72,8 +72,13 @@ using DeviceConvBwdWeightInstance = ck::tensor_operation::device::
...
@@ -72,8 +72,13 @@ using DeviceConvBwdWeightInstance = ck::tensor_operation::device::
8
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
8
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
// clang-format on
using
ReferenceConvBwdWeightInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceConvBwdWeightInstance
=
ReferenceConvBwdWeight
<
InDataType
,
WeiDataType
,
OutDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
>
;
ck
::
tensor_operation
::
host
::
ReferenceConvBwdWeight
<
InDataType
,
WeiDataType
,
OutDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
>
;
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
...
...
example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp
View file @
6c8ca54b
...
@@ -83,8 +83,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
...
@@ -83,8 +83,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
16
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
16
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
RequantReluRequant
>
;
BDataType
,
CDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
RequantReluRequant
>
;
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
...
...
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
View file @
6c8ca54b
...
@@ -56,7 +56,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdl
...
@@ -56,7 +56,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdl
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
...
...
example/16_gemm_reduce/gemm_reduce_xdl_fp16.cpp
View file @
6c8ca54b
...
@@ -52,7 +52,7 @@ using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_
...
@@ -52,7 +52,7 @@ using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
...
...
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
6c8ca54b
...
@@ -389,10 +389,10 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
...
@@ -389,10 +389,10 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
{
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_regs_per_blk
=
4
;
//group_size * num_groups_per_blk;
static
constexpr
index_t
num_regs_per_blk
=
4
;
//
group_size * num_groups_per_blk;
static
constexpr
index_t
num_threads_per_blk
=
16
;
static
constexpr
index_t
num_threads_per_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
4
;
//wave_size / num_threads_per_blk;
static
constexpr
index_t
num_input_blks
=
4
;
//
wave_size / num_threads_per_blk;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
16
;
static
constexpr
index_t
m_per_blk
=
16
;
static
constexpr
index_t
n_per_blk
=
16
;
static
constexpr
index_t
n_per_blk
=
16
;
...
@@ -690,8 +690,9 @@ struct XdlopsGemm
...
@@ -690,8 +690,9 @@ struct XdlopsGemm
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
Run
(
const
FloatA
&
p_a_wave
,
const
FloatB
&
p_b_wave
,
FloatC
&
p_c_thread
)
const
__device__
void
Run
(
const
FloatA
&
p_a_wave
,
const
FloatB
&
p_b_wave
,
FloatC
&
p_c_thread
)
const
{
{
static_assert
(
is_same
<
base_type
,
double
>::
value
||
is_same
<
base_type
,
float
>::
value
||
is_same
<
base_type
,
half_t
>::
value
||
static_assert
(
is_same
<
base_type
,
double
>::
value
||
is_same
<
base_type
,
float
>::
value
||
is_same
<
base_type
,
bhalf_t
>::
value
||
is_same
<
base_type
,
int8_t
>::
value
,
is_same
<
base_type
,
half_t
>::
value
||
is_same
<
base_type
,
bhalf_t
>::
value
||
is_same
<
base_type
,
int8_t
>::
value
,
"base base_type must be double, float, half, bfloat16, and int8_t!"
);
"base base_type must be double, float, half, bfloat16, and int8_t!"
);
static_for
<
0
,
KPack
/
mfma_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
KPack
/
mfma_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
View file @
6c8ca54b
...
@@ -13,6 +13,7 @@ namespace host {
...
@@ -13,6 +13,7 @@ namespace host {
template
<
typename
ADataType
,
template
<
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
typename
CDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
typename
CElementwiseOperation
>
...
@@ -55,12 +56,12 @@ struct ReferenceGemm : public device::BaseOperator
...
@@ -55,12 +56,12 @@ struct ReferenceGemm : public device::BaseOperator
auto
f_mk_kn_mn
=
[
&
](
auto
m
,
auto
n
)
{
auto
f_mk_kn_mn
=
[
&
](
auto
m
,
auto
n
)
{
const
int
K
=
arg
.
a_m_k_
.
mDesc
.
GetLengths
()[
1
];
const
int
K
=
arg
.
a_m_k_
.
mDesc
.
GetLengths
()[
1
];
float
v_acc
=
0
;
AccDataType
v_acc
=
0
;
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
float
v_a
;
AccDataType
v_a
;
float
v_b
;
AccDataType
v_b
;
arg
.
a_element_op_
(
v_a
,
static_cast
<
const
float
>
(
arg
.
a_m_k_
(
m
,
k
)));
arg
.
a_element_op_
(
v_a
,
static_cast
<
const
float
>
(
arg
.
a_m_k_
(
m
,
k
)));
arg
.
b_element_op_
(
v_b
,
static_cast
<
const
float
>
(
arg
.
b_k_n_
(
k
,
n
)));
arg
.
b_element_op_
(
v_b
,
static_cast
<
const
float
>
(
arg
.
b_k_n_
(
k
,
n
)));
...
@@ -68,7 +69,7 @@ struct ReferenceGemm : public device::BaseOperator
...
@@ -68,7 +69,7 @@ struct ReferenceGemm : public device::BaseOperator
v_acc
+=
v_a
*
v_b
;
v_acc
+=
v_a
*
v_b
;
}
}
float
v_c
;
AccDataType
v_c
;
arg
.
c_element_op_
(
v_c
,
v_acc
);
arg
.
c_element_op_
(
v_c
,
v_acc
);
...
...
profiler/include/profile_gemm_impl.hpp
View file @
6c8ca54b
...
@@ -85,6 +85,7 @@ namespace profiler {
...
@@ -85,6 +85,7 @@ namespace profiler {
template
<
typename
ADataType
,
template
<
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
typename
CDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
ALayout
,
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
CLayout
>
typename
CLayout
>
...
@@ -457,8 +458,14 @@ void profile_gemm_impl(int do_verification,
...
@@ -457,8 +458,14 @@ void profile_gemm_impl(int do_verification,
bf16_to_f32_
(
b_k_n
,
b_f32_k_n
);
bf16_to_f32_
(
b_k_n
,
b_f32_k_n
);
bf16_to_f32_
(
c_m_n_device_result
,
c_m_n_device_f32_result
);
bf16_to_f32_
(
c_m_n_device_result
,
c_m_n_device_f32_result
);
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ReferenceGemm
<
float
,
float
,
float
,
AElementOp
,
BElementOp
,
CElementOp
>
;
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
float
,
float
,
float
,
float
,
AElementOp
,
BElementOp
,
CElementOp
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
...
@@ -490,6 +497,7 @@ void profile_gemm_impl(int do_verification,
...
@@ -490,6 +497,7 @@ void profile_gemm_impl(int do_verification,
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
AccDataType
,
AElementOp
,
AElementOp
,
BElementOp
,
BElementOp
,
CElementOp
>
;
CElementOp
>
;
...
...
profiler/include/profile_gemm_reduce_impl.hpp
View file @
6c8ca54b
...
@@ -127,8 +127,13 @@ bool profile_gemm_reduce_impl(int do_verification,
...
@@ -127,8 +127,13 @@ bool profile_gemm_reduce_impl(int do_verification,
if
(
do_verification
)
if
(
do_verification
)
{
{
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
BDataType
,
CDataType
,
DDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
...
...
profiler/include/profile_grouped_gemm_impl.hpp
View file @
6c8ca54b
...
@@ -43,6 +43,7 @@ namespace profiler {
...
@@ -43,6 +43,7 @@ namespace profiler {
template
<
typename
ADataType
,
template
<
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
typename
CDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
ALayout
,
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
CLayout
>
typename
CLayout
>
...
@@ -270,6 +271,7 @@ void profile_grouped_gemm_impl(int do_verification,
...
@@ -270,6 +271,7 @@ void profile_grouped_gemm_impl(int do_verification,
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
AccDataType
,
AElementOp
,
AElementOp
,
BElementOp
,
BElementOp
,
CElementOp
>
;
CElementOp
>
;
...
...
profiler/src/profile_gemm.cpp
View file @
6c8ca54b
...
@@ -68,6 +68,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -68,6 +68,7 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
ck
::
half_t
,
ck
::
profiler
::
profile_gemm_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
@@ -88,6 +89,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -88,6 +89,7 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
ck
::
half_t
,
ck
::
profiler
::
profile_gemm_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
@@ -108,6 +110,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -108,6 +110,7 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
ck
::
half_t
,
ck
::
profiler
::
profile_gemm_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
@@ -128,6 +131,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -128,6 +131,7 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
ck
::
half_t
,
ck
::
profiler
::
profile_gemm_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
@@ -146,6 +150,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -146,6 +150,7 @@ int profile_gemm(int argc, char* argv[])
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
ck
::
profiler
::
profile_gemm_impl
<
float
,
ck
::
profiler
::
profile_gemm_impl
<
float
,
float
,
float
,
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
...
@@ -166,6 +171,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -166,6 +171,7 @@ int profile_gemm(int argc, char* argv[])
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
{
ck
::
profiler
::
profile_gemm_impl
<
float
,
ck
::
profiler
::
profile_gemm_impl
<
float
,
float
,
float
,
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
...
@@ -186,6 +192,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -186,6 +192,7 @@ int profile_gemm(int argc, char* argv[])
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
{
ck
::
profiler
::
profile_gemm_impl
<
float
,
ck
::
profiler
::
profile_gemm_impl
<
float
,
float
,
float
,
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
...
@@ -206,6 +213,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -206,6 +213,7 @@ int profile_gemm(int argc, char* argv[])
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
{
ck
::
profiler
::
profile_gemm_impl
<
float
,
ck
::
profiler
::
profile_gemm_impl
<
float
,
float
,
float
,
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
...
@@ -228,6 +236,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -228,6 +236,7 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
int8_t
,
int8_t
,
int32_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
@@ -248,6 +257,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -248,6 +257,7 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
int8_t
,
int8_t
,
int32_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
@@ -268,6 +278,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -268,6 +278,7 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
int8_t
,
int8_t
,
int32_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
@@ -288,6 +299,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -288,6 +299,7 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
int8_t
,
int8_t
,
int32_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
@@ -308,6 +320,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -308,6 +320,7 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
@@ -328,6 +341,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -328,6 +341,7 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
@@ -348,6 +362,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -348,6 +362,7 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
@@ -368,6 +383,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -368,6 +383,7 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
...
test/gemm/gemm_fp16.cpp
View file @
6c8ca54b
...
@@ -55,6 +55,7 @@ int main()
...
@@ -55,6 +55,7 @@ int main()
using
ADataType
=
ck
::
half_t
;
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
RowMajor
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
RowMajor
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
@@ -74,6 +75,7 @@ int main()
...
@@ -74,6 +75,7 @@ int main()
ADataType
,
ADataType
,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
RowMajor
,
RowMajor
,
RowMajor
,
...
@@ -96,6 +98,7 @@ int main()
...
@@ -96,6 +98,7 @@ int main()
ADataType
,
ADataType
,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
ColumnMajor
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
RowMajor
,
...
@@ -118,6 +121,7 @@ int main()
...
@@ -118,6 +121,7 @@ int main()
ADataType
,
ADataType
,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
AccDataType
,
RowMajor
,
RowMajor
,
RowMajor
,
RowMajor
,
RowMajor
,
RowMajor
,
...
@@ -142,6 +146,7 @@ int main()
...
@@ -142,6 +146,7 @@ int main()
ADataType
,
ADataType
,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
AccDataType
,
RowMajor
,
RowMajor
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
RowMajor
,
...
...
test/gemm/gemm_fp32.cpp
View file @
6c8ca54b
...
@@ -56,6 +56,7 @@ int main()
...
@@ -56,6 +56,7 @@ int main()
using
ADataType
=
float
;
using
ADataType
=
float
;
using
BDataType
=
float
;
using
BDataType
=
float
;
using
CDataType
=
float
;
using
CDataType
=
float
;
using
AccDataType
=
float
,
using
RowMajor
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
RowMajor
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
@@ -75,6 +76,7 @@ int main()
...
@@ -75,6 +76,7 @@ int main()
ADataType
,
ADataType
,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
RowMajor
,
RowMajor
,
RowMajor
,
...
@@ -97,6 +99,7 @@ int main()
...
@@ -97,6 +99,7 @@ int main()
ADataType
,
ADataType
,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
ColumnMajor
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
RowMajor
,
...
@@ -119,6 +122,7 @@ int main()
...
@@ -119,6 +122,7 @@ int main()
ADataType
,
ADataType
,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
AccDataType
,
RowMajor
,
RowMajor
,
RowMajor
,
RowMajor
,
RowMajor
,
RowMajor
,
...
@@ -141,6 +145,7 @@ int main()
...
@@ -141,6 +145,7 @@ int main()
ADataType
,
ADataType
,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
AccDataType
,
RowMajor
,
RowMajor
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
RowMajor
,
...
...
test/gemm/gemm_int8.cpp
View file @
6c8ca54b
...
@@ -49,6 +49,7 @@ int main()
...
@@ -49,6 +49,7 @@ int main()
using
ADataType
=
int8_t
;
using
ADataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
CDataType
=
int8_t
;
using
CDataType
=
int8_t
;
using
AccDataType
=
int32_t
,
using
RowMajor
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
RowMajor
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
@@ -65,6 +66,7 @@ int main()
...
@@ -65,6 +66,7 @@ int main()
ADataType
,
ADataType
,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
RowMajor
,
RowMajor
,
RowMajor
,
...
@@ -83,6 +85,7 @@ int main()
...
@@ -83,6 +85,7 @@ int main()
ADataType
,
ADataType
,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
ColumnMajor
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
RowMajor
,
...
@@ -101,6 +104,7 @@ int main()
...
@@ -101,6 +104,7 @@ int main()
ADataType
,
ADataType
,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
AccDataType
,
RowMajor
,
RowMajor
,
RowMajor
,
RowMajor
,
RowMajor
,
RowMajor
,
...
@@ -119,6 +123,7 @@ int main()
...
@@ -119,6 +123,7 @@ int main()
ADataType
,
ADataType
,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
AccDataType
,
RowMajor
,
RowMajor
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
RowMajor
,
...
...
test/gemm/gemm_util.hpp
View file @
6c8ca54b
...
@@ -106,6 +106,7 @@ template <typename DeviceGemmPtr_,
...
@@ -106,6 +106,7 @@ template <typename DeviceGemmPtr_,
typename
ADataType
,
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
typename
CDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
ALayout
,
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
CLayout
,
typename
CLayout
,
...
@@ -188,6 +189,7 @@ struct TestGemm
...
@@ -188,6 +189,7 @@ struct TestGemm
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
AccDataType
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
;
CElementwiseOperation
>
;
...
@@ -306,6 +308,7 @@ struct TestGemmBF16
...
@@ -306,6 +308,7 @@ struct TestGemmBF16
// use fp32 host kernel to verify bf16 device kernel
// use fp32 host kernel to verify bf16 device kernel
using
ReferenceGemmInstance
=
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
float
,
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
float
,
float
,
float
,
float
,
float
,
float
,
AElementwiseOperation
,
AElementwiseOperation
,
...
...
test/grouped_gemm/grouped_gemm_fp16.cpp
View file @
6c8ca54b
...
@@ -151,8 +151,13 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
...
@@ -151,8 +151,13 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
{
{
c_tensors_device
[
i
]
->
FromDevice
(
c_device_tensors
[
i
].
mData
.
data
());
c_tensors_device
[
i
]
->
FromDevice
(
c_device_tensors
[
i
].
mData
.
data
());
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
BDataType
,
CDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment