Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f9c478e2
Commit
f9c478e2
authored
May 30, 2022
by
ltqin
Browse files
Merge branch 'develop' into bmatrix_skip_lds
parents
7d85d04a
91d8b7d6
Changes
347
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
482 additions
and
731 deletions
+482
-731
profiler/include/profile_conv_fwd_bias_relu_atomic_add_impl.hpp
...er/include/profile_conv_fwd_bias_relu_atomic_add_impl.hpp
+3
-2
profiler/include/profile_conv_fwd_bias_relu_impl.hpp
profiler/include/profile_conv_fwd_bias_relu_impl.hpp
+3
-2
profiler/include/profile_convnd_bwd_data_impl.hpp
profiler/include/profile_convnd_bwd_data_impl.hpp
+9
-8
profiler/include/profile_gemm_bias_2d_impl.hpp
profiler/include/profile_gemm_bias_2d_impl.hpp
+3
-2
profiler/include/profile_gemm_bias_relu_add_impl.hpp
profiler/include/profile_gemm_bias_relu_add_impl.hpp
+3
-2
profiler/include/profile_gemm_bias_relu_impl.hpp
profiler/include/profile_gemm_bias_relu_impl.hpp
+3
-2
profiler/include/profile_gemm_impl.hpp
profiler/include/profile_gemm_impl.hpp
+121
-21
profiler/include/profile_gemm_reduce_impl.hpp
profiler/include/profile_gemm_reduce_impl.hpp
+51
-45
profiler/include/profile_grouped_gemm_impl.hpp
profiler/include/profile_grouped_gemm_impl.hpp
+16
-13
profiler/include/profile_reduce_impl.hpp
profiler/include/profile_reduce_impl.hpp
+124
-309
profiler/src/profile_batched_gemm.cpp
profiler/src/profile_batched_gemm.cpp
+20
-20
profiler/src/profile_batched_gemm_reduce.cpp
profiler/src/profile_batched_gemm_reduce.cpp
+8
-8
profiler/src/profile_conv_bwd_data.cpp
profiler/src/profile_conv_bwd_data.cpp
+0
-195
profiler/src/profile_conv_bwd_weight.cpp
profiler/src/profile_conv_bwd_weight.cpp
+4
-4
profiler/src/profile_conv_fwd_bias_relu.cpp
profiler/src/profile_conv_fwd_bias_relu.cpp
+4
-4
profiler/src/profile_conv_fwd_bias_relu_add.cpp
profiler/src/profile_conv_fwd_bias_relu_add.cpp
+4
-4
profiler/src/profile_conv_fwd_bias_relu_atomic_add.cpp
profiler/src/profile_conv_fwd_bias_relu_atomic_add.cpp
+4
-4
profiler/src/profile_convnd_bwd_data.cpp
profiler/src/profile_convnd_bwd_data.cpp
+48
-48
profiler/src/profile_convnd_fwd.cpp
profiler/src/profile_convnd_fwd.cpp
+18
-18
profiler/src/profile_gemm.cpp
profiler/src/profile_gemm.cpp
+36
-20
No files found.
profiler/include/profile_conv_fwd_bias_relu_atomic_add_impl.hpp
View file @
f9c478e2
...
@@ -119,7 +119,7 @@ template <int NDimSpatial,
...
@@ -119,7 +119,7 @@ template <int NDimSpatial,
void
profile_conv_fwd_bias_relu_atomic_add_impl
(
int
do_verification
,
void
profile_conv_fwd_bias_relu_atomic_add_impl
(
int
do_verification
,
int
init_method
,
int
init_method
,
bool
do_log
,
bool
do_log
,
int
nrepeat
,
bool
time_kernel
,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
...
@@ -275,7 +275,8 @@ void profile_conv_fwd_bias_relu_atomic_add_impl(int do_verification,
...
@@ -275,7 +275,8 @@ void profile_conv_fwd_bias_relu_atomic_add_impl(int do_verification,
{
{
std
::
string
conv_name
=
op_ptr
->
GetTypeString
();
std
::
string
conv_name
=
op_ptr
->
GetTypeString
();
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
nrepeat
);
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
;
...
...
profiler/include/profile_conv_fwd_bias_relu_impl.hpp
View file @
f9c478e2
...
@@ -41,7 +41,7 @@ template <int NDimSpatial,
...
@@ -41,7 +41,7 @@ template <int NDimSpatial,
void
profile_conv_fwd_bias_relu_impl
(
int
do_verification
,
void
profile_conv_fwd_bias_relu_impl
(
int
do_verification
,
int
init_method
,
int
init_method
,
bool
do_log
,
bool
do_log
,
int
nrepeat
,
bool
time_kernel
,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
...
@@ -207,7 +207,8 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
...
@@ -207,7 +207,8 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
{
{
std
::
string
conv_name
=
op_ptr
->
GetTypeString
();
std
::
string
conv_name
=
op_ptr
->
GetTypeString
();
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
nrepeat
);
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
;
...
...
profiler/include/profile_convnd_bwd_data_impl.hpp
View file @
f9c478e2
#pragma once
#pragma once
#include "config.hpp"
#include "config.hpp"
#include "device.hpp"
#include "device.hpp"
#include "conv_
fwd_
util.hpp"
#include "conv_util.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_tensor_generator.hpp"
#include "tensor_layout.hpp"
#include "tensor_layout.hpp"
...
@@ -222,7 +222,7 @@ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
...
@@ -222,7 +222,7 @@ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
{
{
float
max_diff
=
1e-6
;
float
max_diff
=
1e-6
;
for
(
in
t
i
=
0
;
i
<
ref
.
mData
.
size
();
++
i
)
for
(
std
::
size_
t
i
=
0
;
i
<
ref
.
mData
.
size
();
++
i
)
{
{
float
diff
=
std
::
abs
(
double
(
ref
.
mData
[
i
])
-
double
(
result
.
mData
[
i
]));
float
diff
=
std
::
abs
(
double
(
ref
.
mData
[
i
])
-
double
(
result
.
mData
[
i
]));
if
(
max_diff
<
diff
)
if
(
max_diff
<
diff
)
...
@@ -236,16 +236,16 @@ template <typename DataType>
...
@@ -236,16 +236,16 @@ template <typename DataType>
void
show_data_nhwc_layout
(
Tensor
<
DataType
>&
nhwc
)
void
show_data_nhwc_layout
(
Tensor
<
DataType
>&
nhwc
)
{
{
std
::
cout
<<
"["
;
std
::
cout
<<
"["
;
for
(
int
n
=
0
;
n
<
nhwc
.
mDesc
.
GetLengths
()[
0
];
n
++
)
for
(
int
n
=
0
;
n
<
ck
::
type_convert
<
int
>
(
nhwc
.
mDesc
.
GetLengths
()[
0
]
)
;
n
++
)
{
{
std
::
cout
<<
"["
;
std
::
cout
<<
"["
;
for
(
int
hi
=
0
;
hi
<
nhwc
.
mDesc
.
GetLengths
()[
2
];
hi
++
)
for
(
int
hi
=
0
;
hi
<
ck
::
type_convert
<
int
>
(
nhwc
.
mDesc
.
GetLengths
()[
2
]
)
;
hi
++
)
{
{
std
::
cout
<<
"["
;
std
::
cout
<<
"["
;
for
(
int
wi
=
0
;
wi
<
nhwc
.
mDesc
.
GetLengths
()[
3
];
wi
++
)
for
(
int
wi
=
0
;
wi
<
ck
::
type_convert
<
int
>
(
nhwc
.
mDesc
.
GetLengths
()[
3
]
)
;
wi
++
)
{
{
std
::
cout
<<
"["
;
std
::
cout
<<
"["
;
for
(
int
c
=
0
;
c
<
nhwc
.
mDesc
.
GetLengths
()[
1
];
c
++
)
for
(
int
c
=
0
;
c
<
ck
::
type_convert
<
int
>
(
nhwc
.
mDesc
.
GetLengths
()[
1
]
)
;
c
++
)
{
{
std
::
cout
<<
static_cast
<
float
>
(
nhwc
(
n
,
c
,
hi
,
wi
))
<<
" "
;
std
::
cout
<<
static_cast
<
float
>
(
nhwc
(
n
,
c
,
hi
,
wi
))
<<
" "
;
}
}
...
@@ -269,7 +269,7 @@ template <int NDimSpatial,
...
@@ -269,7 +269,7 @@ template <int NDimSpatial,
bool
profile_convnd_bwd_data_impl
(
int
do_verification
,
bool
profile_convnd_bwd_data_impl
(
int
do_verification
,
int
init_method
,
int
init_method
,
bool
do_log
,
bool
do_log
,
int
nrepeat
,
bool
time_kernel
,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
...
@@ -410,7 +410,8 @@ bool profile_convnd_bwd_data_impl(int do_verification,
...
@@ -410,7 +410,8 @@ bool profile_convnd_bwd_data_impl(int do_verification,
{
{
std
::
string
conv_name
=
conv_ptr
->
GetTypeString
();
std
::
string
conv_name
=
conv_ptr
->
GetTypeString
();
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
nrepeat
);
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
std
::
size_t
flop
=
ck
::
utils
::
conv
::
get_flops
(
N
,
C
,
K
,
filter_spatial_lengths
,
output_spatial_lengths
);
ck
::
utils
::
conv
::
get_flops
(
N
,
C
,
K
,
filter_spatial_lengths
,
output_spatial_lengths
);
...
...
profiler/include/profile_gemm_bias_2d_impl.hpp
View file @
f9c478e2
...
@@ -65,7 +65,7 @@ template <typename ADataType,
...
@@ -65,7 +65,7 @@ template <typename ADataType,
void
profile_gemm_bias_2d_impl
(
int
do_verification
,
void
profile_gemm_bias_2d_impl
(
int
do_verification
,
int
init_method
,
int
init_method
,
bool
do_log
,
bool
do_log
,
int
nrepeat
,
bool
time_kernel
,
int
M
,
int
M
,
int
N
,
int
N
,
int
K
,
int
K
,
...
@@ -259,7 +259,8 @@ void profile_gemm_bias_2d_impl(int do_verification,
...
@@ -259,7 +259,8 @@ void profile_gemm_bias_2d_impl(int do_verification,
{
{
std
::
string
gemm_name
=
gemm_ptr
->
GetTypeString
();
std
::
string
gemm_name
=
gemm_ptr
->
GetTypeString
();
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
nrepeat
);
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
...
...
profiler/include/profile_gemm_bias_relu_add_impl.hpp
View file @
f9c478e2
...
@@ -48,7 +48,7 @@ template <typename ADataType,
...
@@ -48,7 +48,7 @@ template <typename ADataType,
void
profile_gemm_bias_relu_add_impl
(
int
do_verification
,
void
profile_gemm_bias_relu_add_impl
(
int
do_verification
,
int
init_method
,
int
init_method
,
bool
do_log
,
bool
do_log
,
int
nrepeat
,
bool
time_kernel
,
int
M
,
int
M
,
int
N
,
int
N
,
int
K
,
int
K
,
...
@@ -232,7 +232,8 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
...
@@ -232,7 +232,8 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
{
{
std
::
string
gemm_name
=
gemm_ptr
->
GetTypeString
();
std
::
string
gemm_name
=
gemm_ptr
->
GetTypeString
();
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
nrepeat
);
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
...
...
profiler/include/profile_gemm_bias_relu_impl.hpp
View file @
f9c478e2
...
@@ -48,7 +48,7 @@ template <typename ADataType,
...
@@ -48,7 +48,7 @@ template <typename ADataType,
void
profile_gemm_bias_relu_impl
(
int
do_verification
,
void
profile_gemm_bias_relu_impl
(
int
do_verification
,
int
init_method
,
int
init_method
,
bool
do_log
,
bool
do_log
,
int
nrepeat
,
bool
time_kernel
,
int
M
,
int
M
,
int
N
,
int
N
,
int
K
,
int
K
,
...
@@ -212,7 +212,8 @@ void profile_gemm_bias_relu_impl(int do_verification,
...
@@ -212,7 +212,8 @@ void profile_gemm_bias_relu_impl(int do_verification,
{
{
std
::
string
gemm_name
=
gemm_ptr
->
GetTypeString
();
std
::
string
gemm_name
=
gemm_ptr
->
GetTypeString
();
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
nrepeat
);
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
...
...
profiler/include/profile_gemm_impl.hpp
View file @
f9c478e2
#pragma once
#pragma once
#include <iomanip>
#include <iomanip>
#include <iostream>
#include <typeinfo>
#include "check_err.hpp"
#include "check_err.hpp"
#include "config.hpp"
#include "config.hpp"
...
@@ -42,14 +44,10 @@ void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector<De
...
@@ -42,14 +44,10 @@ void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector<De
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances
(
void
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances
(
void
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances
(
void
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
...
@@ -74,6 +72,21 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector<Devic
...
@@ -74,6 +72,21 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector<Devic
void
add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
}
// namespace device_gemm_instance
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
@@ -85,13 +98,14 @@ namespace profiler {
...
@@ -85,13 +98,14 @@ namespace profiler {
template
<
typename
ADataType
,
template
<
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
typename
CDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
ALayout
,
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
CLayout
>
typename
CLayout
>
void
profile_gemm_impl
(
int
do_verification
,
void
profile_gemm_impl
(
int
do_verification
,
int
init_method
,
int
init_method
,
bool
do_log
,
bool
do_log
,
int
nrepeat
,
bool
time_kernel
,
int
M
,
int
M
,
int
N
,
int
N
,
int
K
,
int
K
,
...
@@ -125,7 +139,11 @@ void profile_gemm_impl(int do_verification,
...
@@ -125,7 +139,11 @@ void profile_gemm_impl(int do_verification,
std
::
size_t
num_thread
=
1
;
std
::
size_t
num_thread
=
1
;
switch
(
init_method
)
switch
(
init_method
)
{
{
case
0
:
break
;
// case 0: break;
case
0
:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{},
num_thread
);
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
BDataType
>
{},
num_thread
);
break
;
case
1
:
case
1
:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
},
num_thread
);
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
},
num_thread
);
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
},
num_thread
);
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
},
num_thread
);
...
@@ -174,6 +192,9 @@ void profile_gemm_impl(int do_verification,
...
@@ -174,6 +192,9 @@ void profile_gemm_impl(int do_verification,
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances
(
gemm_ptrs
);
}
}
...
@@ -192,6 +213,9 @@ void profile_gemm_impl(int do_verification,
...
@@ -192,6 +213,9 @@ void profile_gemm_impl(int do_verification,
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances
(
gemm_ptrs
);
}
}
...
@@ -210,6 +234,9 @@ void profile_gemm_impl(int do_verification,
...
@@ -210,6 +234,9 @@ void profile_gemm_impl(int do_verification,
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances
(
gemm_ptrs
);
}
}
...
@@ -228,6 +255,9 @@ void profile_gemm_impl(int do_verification,
...
@@ -228,6 +255,9 @@ void profile_gemm_impl(int do_verification,
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances
(
gemm_ptrs
);
}
}
...
@@ -250,6 +280,9 @@ void profile_gemm_impl(int do_verification,
...
@@ -250,6 +280,9 @@ void profile_gemm_impl(int do_verification,
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
gemm_ptrs
);
}
}
...
@@ -268,6 +301,9 @@ void profile_gemm_impl(int do_verification,
...
@@ -268,6 +301,9 @@ void profile_gemm_impl(int do_verification,
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
(
gemm_ptrs
);
...
@@ -289,6 +325,9 @@ void profile_gemm_impl(int do_verification,
...
@@ -289,6 +325,9 @@ void profile_gemm_impl(int do_verification,
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
gemm_ptrs
);
}
}
...
@@ -307,6 +346,9 @@ void profile_gemm_impl(int do_verification,
...
@@ -307,6 +346,9 @@ void profile_gemm_impl(int do_verification,
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
gemm_ptrs
);
}
}
...
@@ -353,28 +395,40 @@ void profile_gemm_impl(int do_verification,
...
@@ -353,28 +395,40 @@ void profile_gemm_impl(int do_verification,
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances
(
gemm_ptrs
);
}
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances
(
gemm_ptrs
);
}
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances
(
gemm_ptrs
);
}
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances
(
gemm_ptrs
);
}
}
}
}
...
@@ -416,12 +470,13 @@ void profile_gemm_impl(int do_verification,
...
@@ -416,12 +470,13 @@ void profile_gemm_impl(int do_verification,
std
::
string
gemm_name
=
gemm_ptr
->
GetTypeString
();
std
::
string
gemm_name
=
gemm_ptr
->
GetTypeString
();
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
nrepeat
);
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
M
+
sizeof
(
CDataType
)
*
M
*
N
;
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
CDataType
)
*
M
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
...
@@ -457,8 +512,14 @@ void profile_gemm_impl(int do_verification,
...
@@ -457,8 +512,14 @@ void profile_gemm_impl(int do_verification,
bf16_to_f32_
(
b_k_n
,
b_f32_k_n
);
bf16_to_f32_
(
b_k_n
,
b_f32_k_n
);
bf16_to_f32_
(
c_m_n_device_result
,
c_m_n_device_f32_result
);
bf16_to_f32_
(
c_m_n_device_result
,
c_m_n_device_f32_result
);
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ReferenceGemm
<
float
,
float
,
float
,
AElementOp
,
BElementOp
,
CElementOp
>
;
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
float
,
float
,
float
,
float
,
AElementOp
,
BElementOp
,
CElementOp
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
...
@@ -490,6 +551,7 @@ void profile_gemm_impl(int do_verification,
...
@@ -490,6 +551,7 @@ void profile_gemm_impl(int do_verification,
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
AccDataType
,
AElementOp
,
AElementOp
,
BElementOp
,
BElementOp
,
CElementOp
>
;
CElementOp
>
;
...
@@ -522,12 +584,50 @@ void profile_gemm_impl(int do_verification,
...
@@ -522,12 +584,50 @@ void profile_gemm_impl(int do_verification,
}
}
else
else
{
{
std
::
cout
<<
"does not support this GEMM problem"
<<
std
::
endl
;
std
::
cout
<<
gemm_ptr
->
GetTypeString
()
<<
" does not support this GEMM problem"
<<
std
::
endl
;
}
}
}
}
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
if
constexpr
(
is_same
<
CDataType
,
float
>::
value
)
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_gemm_name
<<
std
::
endl
;
{
std
::
cout
<<
"Best Perf for datatype = f32"
;
}
else
if
constexpr
(
is_same
<
CDataType
,
half_t
>::
value
)
{
std
::
cout
<<
"Best Perf for datatype = f16"
;
}
else
if
constexpr
(
is_same
<
CDataType
,
bhalf_t
>::
value
)
{
std
::
cout
<<
"Best Perf for datatype = bf16"
;
}
else
if
constexpr
(
is_same
<
CDataType
,
int8_t
>::
value
)
{
std
::
cout
<<
"Best Perf for datatype = int8"
;
}
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
std
::
cout
<<
" ALayout = RowMajor"
;
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
)
{
std
::
cout
<<
" ALayout = ColumnMajor"
;
}
if
constexpr
(
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
std
::
cout
<<
" BLayout = RowMajor"
;
}
else
if
constexpr
(
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
)
{
std
::
cout
<<
" BLayout = ColumnMajor"
;
}
std
::
cout
<<
" M = "
<<
M
<<
" N = "
<<
N
<<
" K = "
<<
K
<<
" StrideA = "
<<
StrideA
<<
" StrideB = "
<<
StrideB
<<
" StrideC = "
<<
StrideC
<<
" : "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_gemm_name
<<
std
::
endl
;
}
}
}
// namespace profiler
}
// namespace profiler
...
...
profiler/include/profile_gemm_reduce_impl.hpp
View file @
f9c478e2
...
@@ -16,11 +16,21 @@ namespace tensor_operation {
...
@@ -16,11 +16,21 @@ namespace tensor_operation {
namespace
device
{
namespace
device
{
namespace
device_gemm_instance
{
namespace
device_gemm_instance
{
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
using
DPtrsGlobal
=
ck
::
Tuple
<
F32
*
,
F32
*>
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
F32
,
F32
,
false
>
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
F32
,
F32
,
false
>
;
using
DInElementOps
=
ck
::
Tuple
<
Identity
,
Square
>
;
using
DOutElementOps
=
ck
::
Tuple
<
Identity
,
Identity
>
;
using
DeviceGemmReduceNoOpPtr
=
ck
::
tensor_operation
::
device
::
DeviceGemmReducePtr
<
using
DeviceGemmReduceNoOpPtr
=
ck
::
tensor_operation
::
device
::
DeviceGemmReducePtr
<
DPtrsGlobal
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
float
,
float
,
false
>>
;
DInElementOps
,
DOutElementOps
>
;
void
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances
(
void
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmReduceNoOpPtr
>&
);
std
::
vector
<
DeviceGemmReduceNoOpPtr
>&
);
...
@@ -52,7 +62,7 @@ template <typename ADataType,
...
@@ -52,7 +62,7 @@ template <typename ADataType,
bool
profile_gemm_reduce_impl
(
int
do_verification
,
bool
profile_gemm_reduce_impl
(
int
do_verification
,
int
init_method
,
int
init_method
,
bool
do_log
,
bool
do_log
,
int
nrepeat
,
bool
time_kernel
,
int
M
,
int
M
,
int
N
,
int
N
,
int
K
,
int
K
,
...
@@ -112,24 +122,35 @@ bool profile_gemm_reduce_impl(int do_verification,
...
@@ -112,24 +122,35 @@ bool profile_gemm_reduce_impl(int do_verification,
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
},
num_thread
);
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
},
num_thread
);
}
}
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
D0ReduceOp
=
ck
::
reduce
::
Add
<
float
>
;
using
D0ReduceOp
=
ck
::
reduce
::
Add
<
float
>
;
using
D1ReduceOp
=
ck
::
reduce
::
Add
<
float
>
;
using
D1ReduceOp
=
ck
::
reduce
::
Add
<
float
>
;
using
D1ElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
float
,
float
,
false
>
;
using
UnaryIdenticElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
float
,
float
,
false
>
;
const
auto
a_element_op
=
AElementOp
{};
using
UnarySquareElementOp
=
const
auto
b_element_op
=
BElementOp
{};
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
float
,
float
,
false
>
;
const
auto
c_element_op
=
CElementOp
{};
using
DxsInElementOps
=
ck
::
Tuple
<
UnaryIdenticElementOp
,
UnarySquareElementOp
>
;
const
auto
d0_reduce_op
=
D0ReduceOp
{};
using
DxsOutElementOps
=
ck
::
Tuple
<
UnaryIdenticElementOp
,
UnaryIdenticElementOp
>
;
const
auto
d1_reduce_op
=
D1ReduceOp
{};
const
auto
d1_element_op
=
D1ElementOp
{};
const
auto
a_element_op
=
AElementOp
{};
const
auto
b_element_op
=
BElementOp
{};
const
auto
c_element_op
=
CElementOp
{};
const
auto
dxs_in_element_op
=
DxsInElementOps
{};
const
auto
dxs_out_element_op
=
DxsOutElementOps
{};
const
auto
d0_reduce_op
=
D0ReduceOp
{};
const
auto
d1_reduce_op
=
D1ReduceOp
{};
if
(
do_verification
)
if
(
do_verification
)
{
{
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
BDataType
,
CDataType
,
DDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
...
@@ -149,7 +170,7 @@ bool profile_gemm_reduce_impl(int do_verification,
...
@@ -149,7 +170,7 @@ bool profile_gemm_reduce_impl(int do_verification,
float
d0_val
=
ck
::
type_convert
<
float
>
(
c_m_n_host_result
(
m
,
n
));
float
d0_val
=
ck
::
type_convert
<
float
>
(
c_m_n_host_result
(
m
,
n
));
float
d1_val
;
float
d1_val
;
d1_e
lement
_op
(
d1_val
,
d0_val
);
UnarySquareE
lement
Op
{}
(
d1_val
,
d0_val
);
d0_reduce_op
(
d0_acc
,
d0_val
);
d0_reduce_op
(
d0_acc
,
d0_val
);
d1_reduce_op
(
d1_acc
,
d1_val
);
d1_reduce_op
(
d1_acc
,
d1_val
);
}
}
...
@@ -165,6 +186,9 @@ bool profile_gemm_reduce_impl(int do_verification,
...
@@ -165,6 +186,9 @@ bool profile_gemm_reduce_impl(int do_verification,
DeviceMem
d0_device_buf
(
sizeof
(
DDataType
)
*
d0_m_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
d0_device_buf
(
sizeof
(
DDataType
)
*
d0_m_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
d1_device_buf
(
sizeof
(
DDataType
)
*
d1_m_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
d1_device_buf
(
sizeof
(
DDataType
)
*
d1_m_device_result
.
mDesc
.
GetElementSpace
());
auto
dxs_global
=
ck
::
make_tuple
(
static_cast
<
DDataType
*>
(
d0_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DDataType
*>
(
d1_device_buf
.
GetDeviceBuffer
()));
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
...
@@ -226,8 +250,7 @@ bool profile_gemm_reduce_impl(int do_verification,
...
@@ -226,8 +250,7 @@ bool profile_gemm_reduce_impl(int do_verification,
gemm_ptr
->
MakeArgumentPointer
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
gemm_ptr
->
MakeArgumentPointer
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DDataType
*>
(
d0_device_buf
.
GetDeviceBuffer
()),
dxs_global
,
static_cast
<
DDataType
*>
(
d1_device_buf
.
GetDeviceBuffer
()),
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -237,42 +260,25 @@ bool profile_gemm_reduce_impl(int do_verification,
...
@@ -237,42 +260,25 @@ bool profile_gemm_reduce_impl(int do_verification,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
,
c_element_op
,
d1_element_op
);
dxs_in_element_op
,
dxs_out_element_op
);
auto
invoker_ptr
=
gemm_ptr
->
MakeInvokerPointer
();
auto
invoker_ptr
=
gemm_ptr
->
MakeInvokerPointer
();
if
(
gemm_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
if
(
gemm_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
{
// warm up
// init DO, D1 to 0
invoker_ptr
->
Run
(
argument_ptr
.
get
());
d0_device_buf
.
SetZero
();
d1_device_buf
.
SetZero
();
// timing
float
total_time
=
0
;
for
(
int
i
=
0
;
i
<
nrepeat
;
++
i
)
{
// init DO, D1 to 0
d0_device_buf
.
SetZero
();
d1_device_buf
.
SetZero
();
KernelTimer
timer
;
timer
.
Start
();
invoker_ptr
->
Run
(
argument_ptr
.
get
());
timer
.
End
();
total_time
+=
timer
.
GetElapsedTime
();
}
float
ave_time
=
total_time
/
nrepeat
;
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
std
::
string
gemm_name
=
gemm_ptr
->
GetTypeString
();
std
::
string
gemm_name
=
gemm_ptr
->
GetTypeString
();
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
M
+
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
CDataType
)
*
M
*
N
+
sizeof
(
CDataType
)
*
N
;
sizeof
(
CDataType
)
*
M
*
N
+
sizeof
(
CDataType
)
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
...
...
profiler/include/profile_grouped_gemm_impl.hpp
View file @
f9c478e2
...
@@ -43,19 +43,20 @@ namespace profiler {
...
@@ -43,19 +43,20 @@ namespace profiler {
template
<
typename
ADataType
,
template
<
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
typename
CDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
ALayout
,
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
CLayout
>
typename
CLayout
>
void
profile_grouped_gemm_impl
(
int
do_verification
,
void
profile_grouped_gemm_impl
(
int
do_verification
,
int
init_method
,
int
init_method
,
bool
do_log
,
bool
do_log
,
int
nrepeat
,
bool
time_kernel
,
std
::
vector
<
int
>
Ms
,
const
std
::
vector
<
int
>
&
Ms
,
std
::
vector
<
int
>
Ns
,
const
std
::
vector
<
int
>
&
Ns
,
std
::
vector
<
int
>
Ks
,
const
std
::
vector
<
int
>
&
Ks
,
std
::
vector
<
int
>
StrideAs
,
const
std
::
vector
<
int
>
&
StrideAs
,
std
::
vector
<
int
>
StrideBs
,
const
std
::
vector
<
int
>
&
StrideBs
,
std
::
vector
<
int
>
StrideCs
)
const
std
::
vector
<
int
>
&
StrideCs
)
{
{
auto
f_host_tensor_descriptor
=
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
...
@@ -71,7 +72,7 @@ void profile_grouped_gemm_impl(int do_verification,
...
@@ -71,7 +72,7 @@ void profile_grouped_gemm_impl(int do_verification,
}
}
};
};
in
t
group_count
=
Ms
.
size
();
std
::
size_
t
group_count
=
Ms
.
size
();
if
(
!
(
group_count
==
Ns
.
size
()
&&
group_count
==
Ks
.
size
()
&&
group_count
==
StrideAs
.
size
()
&&
if
(
!
(
group_count
==
Ns
.
size
()
&&
group_count
==
Ks
.
size
()
&&
group_count
==
StrideAs
.
size
()
&&
group_count
==
StrideBs
.
size
()
&&
group_count
==
StrideCs
.
size
()))
group_count
==
StrideBs
.
size
()
&&
group_count
==
StrideCs
.
size
()))
...
@@ -83,7 +84,7 @@ void profile_grouped_gemm_impl(int do_verification,
...
@@ -83,7 +84,7 @@ void profile_grouped_gemm_impl(int do_verification,
std
::
vector
<
Tensor
<
BDataType
>>
b_k_n
;
std
::
vector
<
Tensor
<
BDataType
>>
b_k_n
;
std
::
vector
<
Tensor
<
CDataType
>>
c_m_n_device_results
;
std
::
vector
<
Tensor
<
CDataType
>>
c_m_n_device_results
;
for
(
in
t
i
=
0
;
i
<
Ms
.
size
()
;
i
++
)
for
(
std
::
size_
t
i
=
0
;
i
<
group_count
;
i
++
)
{
{
a_m_k
.
push_back
(
a_m_k
.
push_back
(
Tensor
<
ADataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ks
[
i
],
StrideAs
[
i
],
ALayout
{})));
Tensor
<
ADataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ks
[
i
],
StrideAs
[
i
],
ALayout
{})));
...
@@ -144,7 +145,7 @@ void profile_grouped_gemm_impl(int do_verification,
...
@@ -144,7 +145,7 @@ void profile_grouped_gemm_impl(int do_verification,
gemm_shapes
.
reserve
(
group_count
);
gemm_shapes
.
reserve
(
group_count
);
for
(
in
t
i
=
0
;
i
<
group_count
;
i
++
)
for
(
std
::
size_
t
i
=
0
;
i
<
group_count
;
i
++
)
{
{
a_device_buf
.
emplace_back
(
a_device_buf
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ADataType
)
*
a_m_k
[
i
].
mDesc
.
GetElementSpace
()));
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ADataType
)
*
a_m_k
[
i
].
mDesc
.
GetElementSpace
()));
...
@@ -231,10 +232,11 @@ void profile_grouped_gemm_impl(int do_verification,
...
@@ -231,10 +232,11 @@ void profile_grouped_gemm_impl(int do_verification,
{
{
std
::
string
gemm_name
=
gemm_ptr
->
GetTypeString
();
std
::
string
gemm_name
=
gemm_ptr
->
GetTypeString
();
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
nrepeat
);
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
0
,
num_btype
=
0
;
std
::
size_t
flop
=
0
,
num_btype
=
0
;
for
(
in
t
i
=
0
;
i
<
gemm_shapes
.
size
();
i
++
)
for
(
std
::
size_
t
i
=
0
;
i
<
gemm_shapes
.
size
();
i
++
)
{
{
flop
+=
std
::
size_t
(
2
)
*
Ms
[
i
]
*
Ns
[
i
]
*
Ks
[
i
];
flop
+=
std
::
size_t
(
2
)
*
Ms
[
i
]
*
Ns
[
i
]
*
Ks
[
i
];
...
@@ -258,7 +260,7 @@ void profile_grouped_gemm_impl(int do_verification,
...
@@ -258,7 +260,7 @@ void profile_grouped_gemm_impl(int do_verification,
if
(
do_verification
)
if
(
do_verification
)
{
{
for
(
in
t
i
=
0
;
i
<
gemm_shapes
.
size
();
i
++
)
for
(
std
::
size_
t
i
=
0
;
i
<
gemm_shapes
.
size
();
i
++
)
{
{
c_device_buf
[
i
]
->
FromDevice
(
c_m_n_device_results
[
i
].
mData
.
data
());
c_device_buf
[
i
]
->
FromDevice
(
c_m_n_device_results
[
i
].
mData
.
data
());
...
@@ -270,6 +272,7 @@ void profile_grouped_gemm_impl(int do_verification,
...
@@ -270,6 +272,7 @@ void profile_grouped_gemm_impl(int do_verification,
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
AccDataType
,
AElementOp
,
AElementOp
,
BElementOp
,
BElementOp
,
CElementOp
>
;
CElementOp
>
;
...
...
profiler/include/profile_reduce_impl.hpp
View file @
f9c478e2
...
@@ -5,74 +5,77 @@
...
@@ -5,74 +5,77 @@
#include "device_reduce_instance.hpp"
#include "device_reduce_instance.hpp"
#include "reduction_enums.hpp"
#include "reduction_enums.hpp"
#include "host_reduction.hpp"
#include "host_reduction.hpp"
#include "host_common_util.hpp"
#include "host_tensor_generator.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
device_reduce_instance
{
namespace
device_reduce_instance
{
template
<
int
Rank
,
int
NumReduceDim
,
int
ReduceOpId
,
int
NanOpt
,
int
IndicesOpt
>
template
<
int
Rank
,
int
NumReduceDim
,
int
ReduceOpId
,
bool
PropagateNan
,
bool
UseIndex
>
struct
ReduceDescription
struct
ReduceDescription
{
{
static
constexpr
int
Rank_
=
Rank
;
static
constexpr
int
Rank_
=
Rank
;
static
constexpr
int
NumReduceDim_
=
NumReduceDim
;
static
constexpr
int
NumReduceDim_
=
NumReduceDim
;
static
constexpr
int
ReduceOpId_
=
ReduceOpId
;
static
constexpr
int
ReduceOpId_
=
ReduceOpId
;
static
constexpr
int
NanOpt_
=
NanOpt
;
static
constexpr
int
PropagateNan_
=
PropagateNan
;
static
constexpr
int
IndicesOpt_
=
IndicesOpt
;
static
constexpr
int
UseIndex_
=
UseIndex
;
};
};
using
reduce_description_instances
=
std
::
tuple
<
ReduceDescription
<
4
,
3
,
0
,
0
,
0
>
,
// for ADD
using
reduce_description_instances
=
ReduceDescription
<
4
,
4
,
0
,
0
,
0
>
,
std
::
tuple
<
ReduceDescription
<
4
,
3
,
0
,
false
,
false
>
,
// for ADD
ReduceDescription
<
4
,
1
,
0
,
0
,
0
>
,
ReduceDescription
<
4
,
4
,
0
,
false
,
false
>
,
ReduceDescription
<
2
,
1
,
0
,
0
,
0
>
,
ReduceDescription
<
4
,
1
,
0
,
false
,
false
>
,
ReduceDescription
<
2
,
1
,
0
,
false
,
false
>
,
ReduceDescription
<
4
,
3
,
5
,
0
,
0
>
,
// for AVG
ReduceDescription
<
4
,
4
,
5
,
0
,
0
>
,
ReduceDescription
<
4
,
3
,
5
,
false
,
false
>
,
// for AVG
ReduceDescription
<
4
,
1
,
5
,
0
,
0
>
,
ReduceDescription
<
4
,
4
,
5
,
false
,
false
>
,
ReduceDescription
<
2
,
1
,
5
,
0
,
0
>
,
ReduceDescription
<
4
,
1
,
5
,
false
,
false
>
,
ReduceDescription
<
2
,
1
,
5
,
false
,
false
>
,
ReduceDescription
<
4
,
3
,
7
,
0
,
0
>
,
// for NORM2
ReduceDescription
<
4
,
4
,
7
,
0
,
0
>
,
ReduceDescription
<
4
,
3
,
7
,
false
,
false
>
,
// for NORM2
ReduceDescription
<
4
,
1
,
7
,
0
,
0
>
,
ReduceDescription
<
4
,
4
,
7
,
false
,
false
>
,
ReduceDescription
<
2
,
1
,
7
,
0
,
0
>
,
ReduceDescription
<
4
,
1
,
7
,
false
,
false
>
,
ReduceDescription
<
2
,
1
,
7
,
false
,
false
>
,
ReduceDescription
<
4
,
3
,
2
,
0
,
0
>
,
// for MIN
ReduceDescription
<
4
,
4
,
2
,
0
,
0
>
,
ReduceDescription
<
4
,
3
,
2
,
false
,
false
>
,
// for MIN
ReduceDescription
<
4
,
1
,
2
,
0
,
0
>
,
ReduceDescription
<
4
,
4
,
2
,
false
,
false
>
,
ReduceDescription
<
2
,
1
,
2
,
0
,
0
>
,
ReduceDescription
<
4
,
1
,
2
,
false
,
false
>
,
ReduceDescription
<
4
,
3
,
3
,
0
,
0
>
,
// for MAX
ReduceDescription
<
2
,
1
,
2
,
false
,
false
>
,
ReduceDescription
<
4
,
4
,
3
,
0
,
0
>
,
ReduceDescription
<
4
,
3
,
3
,
false
,
false
>
,
// for MAX
ReduceDescription
<
4
,
1
,
3
,
0
,
0
>
,
ReduceDescription
<
4
,
4
,
3
,
false
,
false
>
,
ReduceDescription
<
2
,
1
,
3
,
0
,
0
>
,
ReduceDescription
<
4
,
1
,
3
,
false
,
false
>
,
ReduceDescription
<
4
,
3
,
4
,
0
,
0
>
,
// for AMAX
ReduceDescription
<
2
,
1
,
3
,
false
,
false
>
,
ReduceDescription
<
4
,
4
,
4
,
0
,
0
>
,
ReduceDescription
<
4
,
3
,
4
,
false
,
false
>
,
// for AMAX
ReduceDescription
<
4
,
1
,
4
,
0
,
0
>
,
ReduceDescription
<
4
,
4
,
4
,
false
,
false
>
,
ReduceDescription
<
2
,
1
,
4
,
0
,
0
>
,
ReduceDescription
<
4
,
1
,
4
,
false
,
false
>
,
ReduceDescription
<
2
,
1
,
4
,
false
,
false
>
,
ReduceDescription
<
4
,
3
,
2
,
0
,
1
>
,
// for MIN
ReduceDescription
<
4
,
4
,
2
,
0
,
1
>
,
ReduceDescription
<
4
,
3
,
2
,
false
,
true
>
,
// for MIN
ReduceDescription
<
4
,
1
,
2
,
0
,
1
>
,
ReduceDescription
<
4
,
4
,
2
,
false
,
true
>
,
ReduceDescription
<
2
,
1
,
2
,
0
,
1
>
,
ReduceDescription
<
4
,
1
,
2
,
false
,
true
>
,
ReduceDescription
<
4
,
3
,
3
,
0
,
1
>
,
// for MAX
ReduceDescription
<
2
,
1
,
2
,
false
,
true
>
,
ReduceDescription
<
4
,
4
,
3
,
0
,
1
>
,
ReduceDescription
<
4
,
3
,
3
,
false
,
true
>
,
// for MAX
ReduceDescription
<
4
,
1
,
3
,
0
,
1
>
,
ReduceDescription
<
4
,
4
,
3
,
false
,
true
>
,
ReduceDescription
<
2
,
1
,
3
,
0
,
1
>
,
ReduceDescription
<
4
,
1
,
3
,
false
,
true
>
,
ReduceDescription
<
4
,
3
,
4
,
0
,
1
>
,
// for AMAX
ReduceDescription
<
2
,
1
,
3
,
false
,
true
>
,
ReduceDescription
<
4
,
4
,
4
,
0
,
1
>
,
ReduceDescription
<
4
,
3
,
4
,
false
,
true
>
,
// for AMAX
ReduceDescription
<
4
,
1
,
4
,
0
,
1
>
,
ReduceDescription
<
4
,
4
,
4
,
false
,
true
>
,
ReduceDescription
<
2
,
1
,
4
,
0
,
1
>>
;
ReduceDescription
<
4
,
1
,
4
,
false
,
true
>
,
ReduceDescription
<
2
,
1
,
4
,
false
,
true
>>
;
template
<
typename
DescriptionType
>
template
<
typename
DescriptionType
>
bool
description_match
(
const
DescriptionType
&
description
,
bool
description_match
(
const
DescriptionType
&
description
,
int
Rank
,
int
Rank
,
const
std
::
vector
<
int
>&
reduceDims
,
const
std
::
vector
<
int
>&
reduceDims
,
ReduceTensorOp
ReduceOpId
,
ReduceTensorOp
ReduceOpId
,
Nan
Propagat
ion
NanOpt
,
bool
Propagat
eNan
,
ReduceTensorIndices
IndicesOpt
)
bool
UseIndex
)
{
{
if
(
description
.
Rank_
!=
Rank
||
description
.
ReduceOpId_
!=
static_cast
<
int
>
(
ReduceOpId
)
||
if
(
description
.
Rank_
!=
Rank
||
description
.
ReduceOpId_
!=
static_cast
<
int
>
(
ReduceOpId
)
||
description
.
Nan
Opt
_
!=
static_cast
<
int
>
(
Nan
Opt
)
||
description
.
Propagate
Nan_
!=
static_cast
<
int
>
(
Propagate
Nan
)
||
description
.
IndicesOpt
_
!=
static_cast
<
int
>
(
IndicesOpt
))
description
.
UseIndex
_
!=
static_cast
<
int
>
(
UseIndex
))
return
(
false
);
return
(
false
);
if
(
DescriptionType
::
NumReduceDim_
!=
reduceDims
.
size
())
if
(
DescriptionType
::
NumReduceDim_
!=
reduceDims
.
size
())
...
@@ -116,48 +119,18 @@ static inline std::vector<int> get_invariant_dims(const std::vector<int>& reduce
...
@@ -116,48 +119,18 @@ static inline std::vector<int> get_invariant_dims(const std::vector<int>& reduce
return
invariantDims
;
return
invariantDims
;
};
};
template
<
typename
T
>
static
void
dumpBufferToFile
(
const
char
*
fileName
,
T
*
data
,
size_t
dataNumItems
)
{
std
::
ofstream
outFile
(
fileName
,
std
::
ios
::
binary
);
if
(
outFile
)
{
outFile
.
write
(
reinterpret_cast
<
char
*>
(
data
),
dataNumItems
*
sizeof
(
T
));
outFile
.
close
();
std
::
cout
<<
"Write output to file "
<<
fileName
<<
std
::
endl
;
}
else
{
std
::
cout
<<
"Could not open file "
<<
fileName
<<
" for writing"
<<
std
::
endl
;
}
};
// map the data type used by the GPU kernels to the corresponding type used by the host codes
template
<
typename
InType
>
struct
type_mapping
{
using
OutType
=
InType
;
};
template
<
>
struct
type_mapping
<
ck
::
half_t
>
{
using
OutType
=
half_float
::
half
;
};
template
<
typename
InDataType
,
template
<
typename
InDataType
,
typename
AccDataType
,
typename
AccDataType
,
typename
OutDataType
,
typename
OutDataType
,
int
Rank
,
int
Rank
,
int
NumReduceDim
,
int
NumReduceDim
,
ReduceTensorOp
ReduceOpId
,
ReduceTensorOp
ReduceOpId
,
Nan
Propagat
ion
NanOpt
,
bool
Propagat
eNan
,
ReduceTensorIndices
IndicesOpt
>
bool
UseIndex
>
void
profile_reduce_impl_impl
(
bool
do_verification
,
bool
profile_reduce_impl_impl
(
bool
do_verification
,
int
init_method
,
int
init_method
,
bool
do_log
,
bool
do_dumpout
,
bool
do_dumpout
,
int
nrepeat
,
bool
time_kernel
,
const
std
::
vector
<
size_t
>&
inLengths
,
const
std
::
vector
<
size_t
>&
inLengths
,
const
std
::
vector
<
int
>&
reduceDims
,
const
std
::
vector
<
int
>&
reduceDims
,
float
alpha
,
float
alpha
,
...
@@ -166,15 +139,13 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -166,15 +139,13 @@ void profile_reduce_impl_impl(bool do_verification,
using
namespace
ck
::
tensor_operation
::
device
;
using
namespace
ck
::
tensor_operation
::
device
;
using
namespace
ck
::
tensor_operation
::
device
::
device_reduce_instance
;
using
namespace
ck
::
tensor_operation
::
device
::
device_reduce_instance
;
using
namespace
ck
::
host_reduce
;
using
namespace
ck
::
host_reduce
;
using
ck
::
host_common
::
dumpBufferToFile
;
constexpr
bool
op_support_indices
=
constexpr
bool
op_support_indices
=
(
ReduceOpId
==
ReduceTensorOp
::
MIN
||
ReduceOpId
==
ReduceTensorOp
::
MAX
||
(
ReduceOpId
==
ReduceTensorOp
::
MIN
||
ReduceOpId
==
ReduceTensorOp
::
MAX
||
ReduceOpId
==
ReduceTensorOp
::
AMAX
);
ReduceOpId
==
ReduceTensorOp
::
AMAX
);
constexpr
bool
NeedIndices
=
constexpr
bool
OutputIndex
=
(
op_support_indices
&&
UseIndex
);
(
op_support_indices
&&
(
IndicesOpt
!=
ReduceTensorIndices
::
NO_INDICES
));
constexpr
bool
PropagateNan
=
(
NanOpt
==
NanPropagation
::
PROPAGATE_NAN
);
constexpr
bool
out_support_atomic_add
=
std
::
is_same
<
OutDataType
,
float
>::
value
;
constexpr
bool
out_support_atomic_add
=
std
::
is_same
<
OutDataType
,
float
>::
value
;
constexpr
bool
op_support_atomic_add
=
constexpr
bool
op_support_atomic_add
=
...
@@ -195,8 +166,7 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -195,8 +166,7 @@ void profile_reduce_impl_impl(bool do_verification,
(
op_support_indices
&&
!
std
::
is_same
<
AccDataType
,
float
>::
value
);
(
op_support_indices
&&
!
std
::
is_same
<
AccDataType
,
float
>::
value
);
// 1) The indices can only be used when the reduction operation is indexable
// 1) The indices can only be used when the reduction operation is indexable
constexpr
bool
invalid_reduce_3
=
constexpr
bool
invalid_reduce_3
=
(
!
op_support_indices
&&
UseIndex
);
(
!
op_support_indices
&&
IndicesOpt
!=
ReduceTensorIndices
::
NO_INDICES
);
// 1) If InDataType is int8_t, must use int8_t as AccDataType for indexable reduction operations
// 1) If InDataType is int8_t, must use int8_t as AccDataType for indexable reduction operations
// 2) If InDataType is int8_t, must use int32_t as AccDataType for non-indexable reduction
// 2) If InDataType is int8_t, must use int32_t as AccDataType for non-indexable reduction
...
@@ -219,6 +189,8 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -219,6 +189,8 @@ void profile_reduce_impl_impl(bool do_verification,
constexpr
bool
invalid_reduce
=
(
invalid_reduce_1
||
invalid_reduce_2
||
invalid_reduce_3
||
constexpr
bool
invalid_reduce
=
(
invalid_reduce_1
||
invalid_reduce_2
||
invalid_reduce_3
||
invalid_reduce_4
||
invalid_reduce_5
||
invalid_reduce_6
);
invalid_reduce_4
||
invalid_reduce_5
||
invalid_reduce_6
);
bool
pass
=
true
;
if
constexpr
(
!
invalid_reduce
)
if
constexpr
(
!
invalid_reduce
)
{
{
Tensor
<
InDataType
>
in
(
inLengths
);
Tensor
<
InDataType
>
in
(
inLengths
);
...
@@ -282,7 +254,7 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -282,7 +254,7 @@ void profile_reduce_impl_impl(bool do_verification,
if
(
beta
!=
0.0
f
)
if
(
beta
!=
0.0
f
)
out_dev
.
ToDevice
(
out
.
mData
.
data
());
out_dev
.
ToDevice
(
out
.
mData
.
data
());
size_t
indicesSizeInBytes
=
NeedIndices
?
out
.
mDesc
.
GetElementSize
()
*
sizeof
(
int
)
:
0
;
size_t
indicesSizeInBytes
=
OutputIndex
?
out
.
mDesc
.
GetElementSize
()
*
sizeof
(
int
)
:
0
;
DeviceMem
out_indices_dev
(
indicesSizeInBytes
);
DeviceMem
out_indices_dev
(
indicesSizeInBytes
);
...
@@ -295,29 +267,11 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -295,29 +267,11 @@ void profile_reduce_impl_impl(bool do_verification,
using
AccElementwiseOperation_0
=
using
AccElementwiseOperation_0
=
typename
reduce_unary_operator
<
AccDataType
,
ReduceOpId
,
true
,
true
>::
typename
reduce_unary_operator
<
AccDataType
,
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
;
AccElementwiseOperation
;
using
InElementwiseOperation_1
=
typename
reduce_unary_operator
<
AccDataType
,
ReduceOpId
,
true
,
false
>::
InElementwiseOperation
;
using
AccElementwiseOperation_1
=
typename
reduce_unary_operator
<
AccDataType
,
ReduceOpId
,
true
,
false
>::
AccElementwiseOperation
;
using
InElementwiseOperation_2
=
typename
reduce_unary_operator
<
AccDataType
,
ReduceOpId
,
false
,
true
>::
InElementwiseOperation
;
using
AccElementwiseOperation_2
=
typename
reduce_unary_operator
<
AccDataType
,
ReduceOpId
,
false
,
true
>::
AccElementwiseOperation
;
using
DeviceReduceInstPtr0
=
using
DeviceReduceInstPtr0
=
DeviceReducePtr
<
InElementwiseOperation_0
,
AccElementwiseOperation_0
>
;
DeviceReducePtr
<
InElementwiseOperation_0
,
AccElementwiseOperation_0
>
;
using
DeviceReduceInstPtr1
=
DeviceReducePtr
<
InElementwiseOperation_1
,
AccElementwiseOperation_1
>
;
using
DeviceReduceInstPtr2
=
DeviceReducePtr
<
InElementwiseOperation_2
,
AccElementwiseOperation_2
>
;
std
::
vector
<
DeviceReduceInstPtr0
>
reduce0_ptrs
;
std
::
vector
<
DeviceReduceInstPtr0
>
reduce0_ptrs
;
std
::
vector
<
DeviceReduceInstPtr1
>
reduce1_ptrs
;
std
::
vector
<
DeviceReduceInstPtr2
>
reduce2_ptrs
;
add_device_reduce_instance_threadwise
<
InDataType
,
add_device_reduce_instance_threadwise
<
InDataType
,
AccDataType
,
AccDataType
,
...
@@ -325,8 +279,8 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -325,8 +279,8 @@ void profile_reduce_impl_impl(bool do_verification,
Rank
,
Rank
,
NumReduceDim
,
NumReduceDim
,
ReduceOpId
,
ReduceOpId
,
Nan
Opt
,
Propagate
Nan
,
IndicesOpt
>
(
reduce0_ptrs
);
UseIndex
>
(
reduce0_ptrs
);
add_device_reduce_instance_blockwise
<
InDataType
,
add_device_reduce_instance_blockwise
<
InDataType
,
AccDataType
,
AccDataType
,
...
@@ -334,8 +288,8 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -334,8 +288,8 @@ void profile_reduce_impl_impl(bool do_verification,
Rank
,
Rank
,
NumReduceDim
,
NumReduceDim
,
ReduceOpId
,
ReduceOpId
,
Nan
Opt
,
Propagate
Nan
,
IndicesOpt
>
(
reduce0_ptrs
);
UseIndex
>
(
reduce0_ptrs
);
if
constexpr
(
use_atomic_add
)
if
constexpr
(
use_atomic_add
)
{
{
...
@@ -345,35 +299,11 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -345,35 +299,11 @@ void profile_reduce_impl_impl(bool do_verification,
Rank
,
Rank
,
NumReduceDim
,
NumReduceDim
,
ReduceOpId
,
ReduceOpId
,
Nan
Opt
,
Propagate
Nan
,
IndicesOpt
>
(
reduce0_ptrs
);
UseIndex
>
(
reduce0_ptrs
);
}
}
else
{
add_device_reduce_instance_multiblock_partial_reduce
<
InDataType
,
AccDataType
,
OutDataType
,
Rank
,
NumReduceDim
,
ReduceOpId
,
NanOpt
,
IndicesOpt
>
(
reduce1_ptrs
);
};
// used for secondary reduction
if
(
reduce0_ptrs
.
empty
())
if
constexpr
(
!
use_atomic_add
)
{
add_device_reduce_instance_blockwise_second_call
<
AccDataType
,
AccDataType
,
OutDataType
,
Rank
,
NumReduceDim
,
ReduceOpId
,
NanOpt
,
IndicesOpt
>
(
reduce2_ptrs
);
};
if
(
reduce0_ptrs
.
empty
()
&&
reduce1_ptrs
.
empty
())
{
{
throw
std
::
runtime_error
(
"Wrong! No device REDUCE instance found"
);
throw
std
::
runtime_error
(
"Wrong! No device REDUCE instance found"
);
};
};
...
@@ -387,23 +317,25 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -387,23 +317,25 @@ void profile_reduce_impl_impl(bool do_verification,
Rank
,
Rank
,
NumReduceDim
,
NumReduceDim
,
PropagateNan
,
PropagateNan
,
NeedIndices
>
OutputIndex
>
hostReduce
(
in
.
mDesc
,
out_ref
.
mDesc
,
invariantDims
,
reduceDims
);
hostReduce
(
in
.
mDesc
,
out_ref
.
mDesc
,
invariantDims
,
reduceDims
);
hostReduce
.
Run
(
hostReduce
.
Run
(
alpha
,
in
.
mData
.
data
(),
beta
,
out_ref
.
mData
.
data
(),
out_indices_ref
.
mData
.
data
());
alpha
,
in
.
mData
.
data
(),
beta
,
out_ref
.
mData
.
data
(),
out_indices_ref
.
mData
.
data
());
};
};
const
auto
i_inLengths
=
to_int_vector
(
inLengths
);
std
::
vector
<
ck
::
index_t
>
i_inLengths
;
const
auto
i_inStrides
=
to_int_vector
(
inStrides
);
std
::
vector
<
ck
::
index_t
>
i_inStrides
;
const
auto
i_outLengths
=
to_int_vector
(
outLengths
);
std
::
vector
<
ck
::
index_t
>
i_outLengths
;
const
auto
i_outStrides
=
to_int_vector
(
outStrides
);
std
::
vector
<
ck
::
index_t
>
i_outStrides
;
i_inLengths
.
assign
(
inLengths
.
begin
(),
inLengths
.
end
());
i_inStrides
.
assign
(
inStrides
.
begin
(),
inStrides
.
end
());
i_outLengths
.
assign
(
outLengths
.
begin
(),
outLengths
.
end
());
i_outStrides
.
assign
(
outStrides
.
begin
(),
outStrides
.
end
());
for
(
auto
&
reduce_ptr
:
reduce0_ptrs
)
for
(
auto
&
reduce_ptr
:
reduce0_ptrs
)
{
{
auto
wsSizeInBytes
=
reduce_ptr
->
GetWorkspaceSizeInBytes
(
i_inLengths
,
reduceDims
);
DeviceMem
ws_dev
(
wsSizeInBytes
);
InElementwiseOperation_0
in_elementwise_op_0
(
static_cast
<
int32_t
>
(
reduce_total_length
));
InElementwiseOperation_0
in_elementwise_op_0
(
static_cast
<
int32_t
>
(
reduce_total_length
));
AccElementwiseOperation_0
acc_elementwise_op_0
(
AccElementwiseOperation_0
acc_elementwise_op_0
(
...
@@ -417,9 +349,9 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -417,9 +349,9 @@ void profile_reduce_impl_impl(bool do_verification,
alpha
,
alpha
,
beta
,
beta
,
in_dev
.
GetDeviceBuffer
(),
in_dev
.
GetDeviceBuffer
(),
nullptr
,
out_dev
.
GetDeviceBuffer
(),
out_dev
.
GetDeviceBuffer
(),
out_indices_dev
.
GetDeviceBuffer
(),
out_indices_dev
.
GetDeviceBuffer
(),
ws_dev
.
GetDeviceBuffer
(),
in_elementwise_op_0
,
in_elementwise_op_0
,
acc_elementwise_op_0
);
acc_elementwise_op_0
);
...
@@ -430,7 +362,8 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -430,7 +362,8 @@ void profile_reduce_impl_impl(bool do_verification,
auto
invoker_ptr
=
reduce_ptr
->
MakeInvokerPointer
();
auto
invoker_ptr
=
reduce_ptr
->
MakeInvokerPointer
();
float
avg_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
nrepeat
);
float
avg_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
num_bytes
=
std
::
size_t
num_bytes
=
invariant_total_length
*
reduce_total_length
*
sizeof
(
InDataType
)
+
invariant_total_length
*
reduce_total_length
*
sizeof
(
InDataType
)
+
...
@@ -438,8 +371,9 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -438,8 +371,9 @@ void profile_reduce_impl_impl(bool do_verification,
float
gb_per_sec
=
num_bytes
/
1.E6
/
avg_time
;
float
gb_per_sec
=
num_bytes
/
1.E6
/
avg_time
;
std
::
cout
<<
"Perf: "
<<
avg_time
<<
" ms, "
<<
gb_per_sec
<<
" GB/s, "
<<
reduce_name
if
(
time_kernel
)
<<
std
::
endl
;
std
::
cout
<<
"Perf: "
<<
avg_time
<<
" ms, "
<<
gb_per_sec
<<
" GB/s, "
<<
reduce_name
<<
std
::
endl
;
if
(
gb_per_sec
>
best_gb_per_sec
)
if
(
gb_per_sec
>
best_gb_per_sec
)
{
{
...
@@ -449,22 +383,24 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -449,22 +383,24 @@ void profile_reduce_impl_impl(bool do_verification,
if
(
do_verification
)
if
(
do_verification
)
{
{
bool
single_pass
;
out_dev
.
FromDevice
(
out
.
mData
.
data
());
out_dev
.
FromDevice
(
out
.
mData
.
data
());
ck
::
utils
::
check_err
(
out
.
mData
,
out_ref
.
mData
);
single_pass
=
ck
::
utils
::
check_err
(
out
.
mData
,
out_ref
.
mData
);
if
(
NeedIndices
)
if
(
OutputIndex
)
{
{
out_indices_dev
.
FromDevice
(
out_indices
.
mData
.
data
());
out_indices_dev
.
FromDevice
(
out_indices
.
mData
.
data
());
ck
::
utils
::
check_err
(
out_indices
.
mData
,
out_indices_ref
.
mData
);
single_pass
=
single_pass
&&
;
ck
::
utils
::
check_err
(
out_indices
.
mData
,
out_indices_ref
.
mData
)
;
};
};
if
(
do_log
)
if
(
!
single_pass
)
{
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"out_host : "
,
out_ref
.
mData
,
","
)
std
::
cout
<<
"Fail Info: "
<<
reduce_ptr
->
GetTypeString
()
<<
std
::
endl
;
<<
std
::
endl
;
}
LogRangeAsType
<
float
>
(
std
::
cout
<<
"out_device: "
,
out
.
mData
,
","
)
<<
std
::
endl
;
}
;
pass
=
pass
&&
single_pass
;
};
};
if
(
do_dumpout
)
if
(
do_dumpout
)
...
@@ -473,7 +409,7 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -473,7 +409,7 @@ void profile_reduce_impl_impl(bool do_verification,
dumpBufferToFile
(
"dump_out.bin"
,
out
.
mData
.
data
(),
out
.
mDesc
.
GetElementSize
());
dumpBufferToFile
(
"dump_out.bin"
,
out
.
mData
.
data
(),
out
.
mDesc
.
GetElementSize
());
dumpBufferToFile
(
dumpBufferToFile
(
"dump_out_host.bin"
,
out_ref
.
mData
.
data
(),
out_ref
.
mDesc
.
GetElementSize
());
"dump_out_host.bin"
,
out_ref
.
mData
.
data
(),
out_ref
.
mDesc
.
GetElementSize
());
if
(
NeedIndices
)
if
(
OutputIndex
)
{
{
dumpBufferToFile
(
"dump_indices.bin"
,
dumpBufferToFile
(
"dump_indices.bin"
,
out_indices
.
mData
.
data
(),
out_indices
.
mData
.
data
(),
...
@@ -485,156 +421,34 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -485,156 +421,34 @@ void profile_reduce_impl_impl(bool do_verification,
};
};
};
};
for
(
auto
&
reduce_ptr
:
reduce1_ptrs
)
if
(
time_kernel
)
{
std
::
cout
<<
"Best Perf: "
<<
best_avg_time
<<
" ms, "
<<
best_gb_per_sec
<<
" GB/s"
auto
wsSizeInBytes
=
reduce_ptr
->
GetWorkspaceSizeInBytes
(
i_inLengths
,
reduceDims
);
<<
std
::
endl
;
DeviceMem
ws_dev
(
wsSizeInBytes
);
InElementwiseOperation_1
in_elementwise_op_1
(
static_cast
<
int32_t
>
(
reduce_total_length
));
AccElementwiseOperation_1
acc_elementwise_op_1
(
static_cast
<
int32_t
>
(
reduce_total_length
));
auto
argument_ptr
=
reduce_ptr
->
MakeArgumentPointer
(
i_inLengths
,
i_inStrides
,
i_outLengths
,
i_outStrides
,
reduceDims
,
alpha
,
beta
,
in_dev
.
GetDeviceBuffer
(),
out_dev
.
GetDeviceBuffer
(),
out_indices_dev
.
GetDeviceBuffer
(),
ws_dev
.
GetDeviceBuffer
(),
in_elementwise_op_1
,
acc_elementwise_op_1
);
if
(
!
reduce_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
continue
;
std
::
string
reduce_name
=
reduce_ptr
->
GetTypeString
();
auto
invoker_ptr
=
reduce_ptr
->
MakeInvokerPointer
();
float
avg_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
nrepeat
);
std
::
size_t
num_bytes
=
invariant_total_length
*
reduce_total_length
*
sizeof
(
InDataType
)
+
invariant_total_length
*
sizeof
(
OutDataType
);
std
::
vector
<
int
>
inLengths2
=
reduce_ptr
->
GetWorkspace2dLengths
(
argument_ptr
.
get
());
std
::
vector
<
int
>
inStrides2
{
inLengths2
[
1
],
1
};
for
(
auto
&
reduce2_ptr
:
reduce2_ptrs
)
{
InElementwiseOperation_2
in_elementwise_op_2
(
static_cast
<
int32_t
>
(
reduce_total_length
));
AccElementwiseOperation_2
acc_elementwise_op_2
(
static_cast
<
int32_t
>
(
reduce_total_length
));
auto
argument2_ptr
=
reduce2_ptr
->
MakeArgumentPointer
(
inLengths2
,
inStrides2
,
i_outLengths
,
i_outStrides
,
reduceDims
,
alpha
,
beta
,
ws_dev
.
GetDeviceBuffer
(),
out_dev
.
GetDeviceBuffer
(),
out_indices_dev
.
GetDeviceBuffer
(),
ws_dev
.
GetDeviceBuffer
(),
in_elementwise_op_2
,
acc_elementwise_op_2
);
if
(
!
reduce2_ptr
->
IsSupportedArgument
(
argument2_ptr
.
get
()))
continue
;
std
::
string
reduce2_name
=
reduce2_ptr
->
GetTypeString
();
auto
invoker2_ptr
=
reduce2_ptr
->
MakeInvokerPointer
();
float
avg_time_2
=
invoker2_ptr
->
Run
(
argument2_ptr
.
get
(),
nrepeat
);
std
::
size_t
num_bytes_2
=
static_cast
<
size_t
>
(
inLengths2
[
0
])
*
inLengths2
[
1
]
*
sizeof
(
AccDataType
);
float
gb_per_sec
=
(
num_bytes
+
num_bytes_2
)
/
1.E6
/
(
avg_time
+
avg_time_2
);
std
::
cout
<<
"Perf: "
<<
(
avg_time
+
avg_time_2
)
<<
" ms, "
<<
gb_per_sec
<<
" GB/s, "
<<
reduce_name
<<
" => "
<<
reduce2_name
<<
std
::
endl
;
if
(
gb_per_sec
>
best_gb_per_sec
)
{
best_avg_time
=
avg_time
+
avg_time_2
;
best_gb_per_sec
=
gb_per_sec
;
}
if
(
do_verification
)
{
out_dev
.
FromDevice
(
out
.
mData
.
data
());
ck
::
utils
::
check_err
(
out
.
mData
,
out_ref
.
mData
);
if
(
NeedIndices
)
{
out_indices_dev
.
FromDevice
(
out_indices
.
mData
.
data
());
ck
::
utils
::
check_err
(
out_indices
.
mData
,
out_indices_ref
.
mData
);
;
};
if
(
do_log
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"out_host : "
,
out_ref
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"out_device: "
,
out
.
mData
,
","
)
<<
std
::
endl
;
}
}
if
(
do_dumpout
)
{
dumpBufferToFile
(
"dump_in.bin"
,
in
.
mData
.
data
(),
in
.
mDesc
.
GetElementSize
());
dumpBufferToFile
(
"dump_out.bin"
,
out
.
mData
.
data
(),
out
.
mDesc
.
GetElementSize
());
dumpBufferToFile
(
"dump_out_host.bin"
,
out_ref
.
mData
.
data
(),
out_ref
.
mDesc
.
GetElementSize
());
if
(
NeedIndices
)
{
dumpBufferToFile
(
"dump_indices.bin"
,
out_indices
.
mData
.
data
(),
out_indices
.
mDesc
.
GetElementSize
());
dumpBufferToFile
(
"dump_indices_host.bin"
,
out_indices_ref
.
mData
.
data
(),
out_indices_ref
.
mDesc
.
GetElementSize
());
};
};
};
};
std
::
cout
<<
"Best Perf: "
<<
best_avg_time
<<
" ms, "
<<
best_gb_per_sec
<<
" GB/s"
<<
std
::
endl
;
}
}
else
else
{
{
std
::
cout
<<
"The requested reduction operation is not supported, please check !!!"
std
::
cout
<<
"The requested reduction operation is not supported, please check !!!"
<<
std
::
endl
;
<<
std
::
endl
;
};
};
return
pass
;
};
};
template
<
typename
InDataType
,
typename
AccDataType
,
typename
OutDataType
>
template
<
typename
InDataType
,
typename
AccDataType
,
typename
OutDataType
>
void
profile_reduce_impl
(
bool
do_verification
,
bool
profile_reduce_impl
(
bool
do_verification
,
int
init_method
,
int
init_method
,
bool
do_log
,
bool
do_dumpout
,
bool
do_dumpout
,
int
nrepeat
,
bool
time_kernel
,
const
std
::
vector
<
size_t
>&
inLengths
,
const
std
::
vector
<
size_t
>&
inLengths
,
const
std
::
vector
<
int
>&
reduceDims
,
const
std
::
vector
<
int
>&
reduceDims
,
ReduceTensorOp
ReduceOpId
,
ReduceTensorOp
ReduceOpId
,
Nan
Propagat
ion
NanOpt
,
bool
Propagat
eNan
,
ReduceTensorIndices
IndicesOpt
,
bool
UseIndex
,
float
alpha
,
float
alpha
,
float
beta
)
float
beta
)
{
{
bool
matched
=
false
;
bool
matched
=
false
;
bool
pass
=
true
;
using
tuple_of_description_instances
=
using
tuple_of_description_instances
=
tensor_operation
::
device
::
device_reduce_instance
::
reduce_description_instances
;
tensor_operation
::
device
::
device_reduce_instance
::
reduce_description_instances
;
...
@@ -648,29 +462,30 @@ void profile_reduce_impl(bool do_verification,
...
@@ -648,29 +462,30 @@ void profile_reduce_impl(bool do_verification,
using
descType
=
remove_cvref_t
<
decltype
(
std
::
get
<
i
>
(
tuple_object
))
>
;
using
descType
=
remove_cvref_t
<
decltype
(
std
::
get
<
i
>
(
tuple_object
))
>
;
if
(
!
description_match
(
if
(
!
description_match
(
descType
{},
inLengths
.
size
(),
reduceDims
,
ReduceOpId
,
NanOpt
,
IndicesOpt
))
descType
{},
inLengths
.
size
(),
reduceDims
,
ReduceOpId
,
PropagateNan
,
UseIndex
))
return
;
return
;
profile_reduce_impl_impl
<
InDataType
,
pass
=
pass
&&
AccDataType
,
profile_reduce_impl_impl
<
InDataType
,
OutDataType
,
AccDataType
,
descType
::
Rank_
,
OutDataType
,
descType
::
NumReduceDim_
,
descType
::
Rank_
,
static_cast
<
ReduceTensorOp
>
(
descType
::
ReduceOpId_
),
descType
::
NumReduceDim_
,
static_cast
<
NanPropagation
>
(
descType
::
NanOpt_
),
static_cast
<
ReduceTensorOp
>
(
descType
::
ReduceOpId_
),
static_cast
<
ReduceTensorIndices
>
(
descType
::
IndicesOpt_
)
>
(
static_cast
<
bool
>
(
descType
::
PropagateNan_
),
do_verification
,
static_cast
<
bool
>
(
descType
::
UseIndex_
)
>
(
do_verification
,
init_method
,
init_method
,
do_log
,
do_dumpout
,
do_dumpout
,
time_kernel
,
nrepeat
,
inLengths
,
inLengths
,
reduceDims
,
reduceDims
,
alpha
,
alpha
,
beta
);
beta
);
matched
=
true
;
matched
=
true
;
});
});
return
pass
;
};
};
}
// namespace profiler
}
// namespace profiler
...
...
profiler/src/profile_batched_gemm.cpp
View file @
f9c478e2
...
@@ -48,8 +48,8 @@ int profile_batched_gemm(int argc, char* argv[])
...
@@ -48,8 +48,8 @@ int profile_batched_gemm(int argc, char* argv[])
printf
(
" 3: A[g, k, m] * B[g, n, k] = C[g, m, n])
\n
"
);
printf
(
" 3: A[g, k, m] * B[g, n, k] = C[g, m, n])
\n
"
);
printf
(
"arg4: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg4: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg5: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg5: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg
8
: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg
6
: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg7:
run
kernel
# of times (>1
)
\n
"
);
printf
(
"arg7:
time
kernel
(0=n0, 1=yes
)
\n
"
);
printf
(
"arg8 to 14: M, N, K, StrideA, StrideB, StrideC, BatchCount
\n
"
);
printf
(
"arg8 to 14: M, N, K, StrideA, StrideB, StrideC, BatchCount
\n
"
);
exit
(
1
);
exit
(
1
);
}
}
...
@@ -59,7 +59,7 @@ int profile_batched_gemm(int argc, char* argv[])
...
@@ -59,7 +59,7 @@ int profile_batched_gemm(int argc, char* argv[])
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
7
]);
const
bool
time_kernel
=
std
::
stoi
(
argv
[
7
]);
const
int
M
=
std
::
stoi
(
argv
[
8
]);
const
int
M
=
std
::
stoi
(
argv
[
8
]);
const
int
N
=
std
::
stoi
(
argv
[
9
]);
const
int
N
=
std
::
stoi
(
argv
[
9
]);
...
@@ -82,7 +82,7 @@ int profile_batched_gemm(int argc, char* argv[])
...
@@ -82,7 +82,7 @@ int profile_batched_gemm(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -102,7 +102,7 @@ int profile_batched_gemm(int argc, char* argv[])
...
@@ -102,7 +102,7 @@ int profile_batched_gemm(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -122,7 +122,7 @@ int profile_batched_gemm(int argc, char* argv[])
...
@@ -122,7 +122,7 @@ int profile_batched_gemm(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -142,7 +142,7 @@ int profile_batched_gemm(int argc, char* argv[])
...
@@ -142,7 +142,7 @@ int profile_batched_gemm(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -162,7 +162,7 @@ int profile_batched_gemm(int argc, char* argv[])
...
@@ -162,7 +162,7 @@ int profile_batched_gemm(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -182,7 +182,7 @@ int profile_batched_gemm(int argc, char* argv[])
...
@@ -182,7 +182,7 @@ int profile_batched_gemm(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -202,7 +202,7 @@ int profile_batched_gemm(int argc, char* argv[])
...
@@ -202,7 +202,7 @@ int profile_batched_gemm(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -222,7 +222,7 @@ int profile_batched_gemm(int argc, char* argv[])
...
@@ -222,7 +222,7 @@ int profile_batched_gemm(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -242,7 +242,7 @@ int profile_batched_gemm(int argc, char* argv[])
...
@@ -242,7 +242,7 @@ int profile_batched_gemm(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -262,7 +262,7 @@ int profile_batched_gemm(int argc, char* argv[])
...
@@ -262,7 +262,7 @@ int profile_batched_gemm(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -282,7 +282,7 @@ int profile_batched_gemm(int argc, char* argv[])
...
@@ -282,7 +282,7 @@ int profile_batched_gemm(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -302,7 +302,7 @@ int profile_batched_gemm(int argc, char* argv[])
...
@@ -302,7 +302,7 @@ int profile_batched_gemm(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -322,7 +322,7 @@ int profile_batched_gemm(int argc, char* argv[])
...
@@ -322,7 +322,7 @@ int profile_batched_gemm(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -342,7 +342,7 @@ int profile_batched_gemm(int argc, char* argv[])
...
@@ -342,7 +342,7 @@ int profile_batched_gemm(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -362,7 +362,7 @@ int profile_batched_gemm(int argc, char* argv[])
...
@@ -362,7 +362,7 @@ int profile_batched_gemm(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -382,7 +382,7 @@ int profile_batched_gemm(int argc, char* argv[])
...
@@ -382,7 +382,7 @@ int profile_batched_gemm(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -396,5 +396,5 @@ int profile_batched_gemm(int argc, char* argv[])
...
@@ -396,5 +396,5 @@ int profile_batched_gemm(int argc, char* argv[])
throw
std
::
runtime_error
(
"wrong! this GEMM data_type & layout is not implemented"
);
throw
std
::
runtime_error
(
"wrong! this GEMM data_type & layout is not implemented"
);
}
}
return
1
;
return
0
;
}
}
profiler/src/profile_batched_gemm_reduce.cpp
View file @
f9c478e2
...
@@ -33,8 +33,8 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
...
@@ -33,8 +33,8 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
printf
(
" 3: A[k, m] * B[n, k] = C[m, n])
\n
"
);
printf
(
" 3: A[k, m] * B[n, k] = C[m, n])
\n
"
);
printf
(
"arg4: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg4: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg5: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg5: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg
8
: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg
6
: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg7:
run
kernel
# of times (>1
)
\n
"
);
printf
(
"arg7:
time
kernel
(0=n0, 1=yes
)
\n
"
);
printf
(
"arg8 to 14: M, N, K, StrideA, StrideB, StrideC, BatchCount
\n
"
);
printf
(
"arg8 to 14: M, N, K, StrideA, StrideB, StrideC, BatchCount
\n
"
);
printf
(
"arg15: split k into mulitiple batch
\n
"
);
printf
(
"arg15: split k into mulitiple batch
\n
"
);
exit
(
1
);
exit
(
1
);
...
@@ -45,7 +45,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
...
@@ -45,7 +45,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
7
]);
const
bool
time_kernel
=
std
::
stoi
(
argv
[
7
]);
const
int
M
=
std
::
stoi
(
argv
[
8
]);
const
int
M
=
std
::
stoi
(
argv
[
8
]);
const
int
N
=
std
::
stoi
(
argv
[
9
]);
const
int
N
=
std
::
stoi
(
argv
[
9
]);
...
@@ -69,7 +69,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
...
@@ -69,7 +69,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -91,7 +91,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
...
@@ -91,7 +91,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -113,7 +113,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
...
@@ -113,7 +113,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -135,7 +135,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
...
@@ -135,7 +135,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -149,5 +149,5 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
...
@@ -149,5 +149,5 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
throw
std
::
runtime_error
(
"wrong! this data_type & layout is not implemented"
);
throw
std
::
runtime_error
(
"wrong! this data_type & layout is not implemented"
);
}
}
return
1
;
return
0
;
}
}
profiler/src/profile_conv_bwd_data.cpp
deleted
100644 → 0
View file @
7d85d04a
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "profile_conv_bwd_data_impl.hpp"
enum
struct
ConvDataType
{
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
BF16_BF16_BF16
,
// 2
INT8_INT8_INT8
,
// 3
};
enum
struct
ConvInputLayout
{
NCHW
,
// 0
NHWC
,
// 1
};
enum
struct
ConvWeightLayout
{
KCYX
,
// 0
KYXC
,
// 1
};
enum
struct
ConvOutputLayout
{
NKHW
,
// 0
NHWK
,
// 1
};
int
profile_conv_bwd_data
(
int
argc
,
char
*
argv
[])
{
if
(
argc
!=
25
)
{
printf
(
"arg1: tensor operation (conv_bwd: BackwardConvolution)
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16)
\n
"
);
printf
(
"arg3: input tensor layout (0: NCHW; 1: NHWC)
\n
"
);
printf
(
"arg4: weight tensor layout (0: KCYX; 1: KYXC)
\n
"
);
printf
(
"arg5: output tensor layout (0: NKHW; 1: NHWK)
\n
"
);
printf
(
"arg6: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg7: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg8: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg9: run kernel # of times (>1)
\n
"
);
printf
(
"arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx
\n
"
);
exit
(
1
);
}
const
auto
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
auto
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
9
]);
const
ck
::
index_t
N
=
std
::
stoi
(
argv
[
10
]);
const
ck
::
index_t
K
=
std
::
stoi
(
argv
[
11
]);
const
ck
::
index_t
C
=
std
::
stoi
(
argv
[
12
]);
const
ck
::
index_t
Y
=
std
::
stoi
(
argv
[
13
]);
const
ck
::
index_t
X
=
std
::
stoi
(
argv
[
14
]);
const
ck
::
index_t
Hi
=
std
::
stoi
(
argv
[
15
]);
const
ck
::
index_t
Wi
=
std
::
stoi
(
argv
[
16
]);
const
ck
::
index_t
conv_stride_h
=
std
::
stoi
(
argv
[
17
]);
const
ck
::
index_t
conv_stride_w
=
std
::
stoi
(
argv
[
18
]);
const
ck
::
index_t
conv_dilation_h
=
std
::
stoi
(
argv
[
19
]);
const
ck
::
index_t
conv_dilation_w
=
std
::
stoi
(
argv
[
20
]);
const
ck
::
index_t
in_left_pad_h
=
std
::
stoi
(
argv
[
21
]);
const
ck
::
index_t
in_left_pad_w
=
std
::
stoi
(
argv
[
22
]);
const
ck
::
index_t
in_right_pad_h
=
std
::
stoi
(
argv
[
23
]);
const
ck
::
index_t
in_right_pad_w
=
std
::
stoi
(
argv
[
24
]);
const
ck
::
index_t
YEff
=
(
Y
-
1
)
*
conv_dilation_h
+
1
;
const
ck
::
index_t
XEff
=
(
X
-
1
)
*
conv_dilation_w
+
1
;
const
ck
::
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
YEff
)
/
conv_stride_h
+
1
;
const
ck
::
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
XEff
)
/
conv_stride_w
+
1
;
if
(
data_type
==
ConvDataType
::
F32_F32_F32
&&
in_layout
==
ConvInputLayout
::
NHWC
&&
wei_layout
==
ConvWeightLayout
::
KYXC
&&
out_layout
==
ConvOutputLayout
::
NHWK
)
{
ck
::
profiler
::
profile_conv_bwd_data_impl
<
2
,
float
,
float
,
float
,
float
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
N
,
K
,
C
,
std
::
vector
<
ck
::
index_t
>
{
Hi
,
Wi
},
std
::
vector
<
ck
::
index_t
>
{
Y
,
X
},
std
::
vector
<
ck
::
index_t
>
{
Ho
,
Wo
},
std
::
vector
<
ck
::
index_t
>
{
conv_stride_h
,
conv_stride_w
},
std
::
vector
<
ck
::
index_t
>
{
conv_dilation_h
,
conv_dilation_w
},
std
::
vector
<
ck
::
index_t
>
{
in_left_pad_h
,
in_left_pad_w
},
std
::
vector
<
ck
::
index_t
>
{
in_right_pad_h
,
in_right_pad_w
});
}
else
if
(
data_type
==
ConvDataType
::
F16_F16_F16
&&
in_layout
==
ConvInputLayout
::
NHWC
&&
wei_layout
==
ConvWeightLayout
::
KYXC
&&
out_layout
==
ConvOutputLayout
::
NHWK
)
{
ck
::
profiler
::
profile_conv_bwd_data_impl
<
2
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
N
,
K
,
C
,
std
::
vector
<
ck
::
index_t
>
{
Hi
,
Wi
},
std
::
vector
<
ck
::
index_t
>
{
Y
,
X
},
std
::
vector
<
ck
::
index_t
>
{
Ho
,
Wo
},
std
::
vector
<
ck
::
index_t
>
{
conv_stride_h
,
conv_stride_w
},
std
::
vector
<
ck
::
index_t
>
{
conv_dilation_h
,
conv_dilation_w
},
std
::
vector
<
ck
::
index_t
>
{
in_left_pad_h
,
in_left_pad_w
},
std
::
vector
<
ck
::
index_t
>
{
in_right_pad_h
,
in_right_pad_w
});
}
else
if
(
data_type
==
ConvDataType
::
BF16_BF16_BF16
&&
in_layout
==
ConvInputLayout
::
NHWC
&&
wei_layout
==
ConvWeightLayout
::
KYXC
&&
out_layout
==
ConvOutputLayout
::
NHWK
)
{
ck
::
profiler
::
profile_conv_bwd_data_impl
<
2
,
uint16_t
,
uint16_t
,
uint16_t
,
float
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
N
,
K
,
C
,
std
::
vector
<
ck
::
index_t
>
{
Hi
,
Wi
},
std
::
vector
<
ck
::
index_t
>
{
Y
,
X
},
std
::
vector
<
ck
::
index_t
>
{
Ho
,
Wo
},
std
::
vector
<
ck
::
index_t
>
{
conv_stride_h
,
conv_stride_w
},
std
::
vector
<
ck
::
index_t
>
{
conv_dilation_h
,
conv_dilation_w
},
std
::
vector
<
ck
::
index_t
>
{
in_left_pad_h
,
in_left_pad_w
},
std
::
vector
<
ck
::
index_t
>
{
in_right_pad_h
,
in_right_pad_w
});
}
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
&&
in_layout
==
ConvInputLayout
::
NHWC
&&
wei_layout
==
ConvWeightLayout
::
KYXC
&&
out_layout
==
ConvOutputLayout
::
NHWK
)
{
ck
::
profiler
::
profile_conv_bwd_data_impl
<
2
,
int8_t
,
int8_t
,
int8_t
,
int32_t
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
N
,
K
,
C
,
std
::
vector
<
ck
::
index_t
>
{
Hi
,
Wi
},
std
::
vector
<
ck
::
index_t
>
{
Y
,
X
},
std
::
vector
<
ck
::
index_t
>
{
Ho
,
Wo
},
std
::
vector
<
ck
::
index_t
>
{
conv_stride_h
,
conv_stride_w
},
std
::
vector
<
ck
::
index_t
>
{
conv_dilation_h
,
conv_dilation_w
},
std
::
vector
<
ck
::
index_t
>
{
in_left_pad_h
,
in_left_pad_w
},
std
::
vector
<
ck
::
index_t
>
{
in_right_pad_h
,
in_right_pad_w
});
}
else
{
throw
std
::
runtime_error
(
"wrong! this Conv data_type & layout is not implemented"
);
}
return
1
;
}
profiler/src/profile_conv_bwd_weight.cpp
View file @
f9c478e2
...
@@ -58,7 +58,7 @@ int profile_conv_bwd_weight(int argc, char* argv[])
...
@@ -58,7 +58,7 @@ int profile_conv_bwd_weight(int argc, char* argv[])
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
9
]);
const
bool
time_kernel
=
std
::
stoi
(
argv
[
9
]);
const
ck
::
index_t
N
=
std
::
stoi
(
argv
[
10
]);
const
ck
::
index_t
N
=
std
::
stoi
(
argv
[
10
]);
const
ck
::
index_t
K
=
std
::
stoi
(
argv
[
11
]);
const
ck
::
index_t
K
=
std
::
stoi
(
argv
[
11
]);
...
@@ -98,7 +98,7 @@ int profile_conv_bwd_weight(int argc, char* argv[])
...
@@ -98,7 +98,7 @@ int profile_conv_bwd_weight(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
N
,
N
,
K
,
K
,
C
,
C
,
...
@@ -124,7 +124,7 @@ int profile_conv_bwd_weight(int argc, char* argv[])
...
@@ -124,7 +124,7 @@ int profile_conv_bwd_weight(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
N
,
N
,
K
,
K
,
C
,
C
,
...
@@ -142,5 +142,5 @@ int profile_conv_bwd_weight(int argc, char* argv[])
...
@@ -142,5 +142,5 @@ int profile_conv_bwd_weight(int argc, char* argv[])
throw
std
::
runtime_error
(
"wrong! this Conv data_type & layout is not implemented"
);
throw
std
::
runtime_error
(
"wrong! this Conv data_type & layout is not implemented"
);
}
}
return
1
;
return
0
;
}
}
profiler/src/profile_conv_fwd_bias_relu.cpp
View file @
f9c478e2
...
@@ -42,7 +42,7 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[])
...
@@ -42,7 +42,7 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[])
printf
(
"arg6: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg6: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg7: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg7: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg8: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg8: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg9:
run
kernel
# of times (>1
)
\n
"
);
printf
(
"arg9:
time
kernel
(0=n0, 1=yes
)
\n
"
);
printf
(
"arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
printf
(
"arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx
\n
"
);
"RightPx
\n
"
);
exit
(
1
);
exit
(
1
);
...
@@ -55,7 +55,7 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[])
...
@@ -55,7 +55,7 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[])
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
9
]);
const
bool
time_kernel
=
std
::
stoi
(
argv
[
9
]);
const
ck
::
index_t
N
=
std
::
stoi
(
argv
[
10
]);
const
ck
::
index_t
N
=
std
::
stoi
(
argv
[
10
]);
const
ck
::
index_t
K
=
std
::
stoi
(
argv
[
11
]);
const
ck
::
index_t
K
=
std
::
stoi
(
argv
[
11
]);
...
@@ -93,7 +93,7 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[])
...
@@ -93,7 +93,7 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
N
,
N
,
K
,
K
,
C
,
C
,
...
@@ -110,5 +110,5 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[])
...
@@ -110,5 +110,5 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[])
throw
std
::
runtime_error
(
"wrong! data_type & layout for this operator is not implemented"
);
throw
std
::
runtime_error
(
"wrong! data_type & layout for this operator is not implemented"
);
}
}
return
1
;
return
0
;
}
}
profiler/src/profile_conv_fwd_bias_relu_add.cpp
View file @
f9c478e2
...
@@ -43,7 +43,7 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[])
...
@@ -43,7 +43,7 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[])
printf
(
"arg6: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg6: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg7: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg7: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg8: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg8: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg9:
run
kernel
# of times (>1
)
\n
"
);
printf
(
"arg9:
time
kernel
(0=n0, 1=yes
)
\n
"
);
printf
(
"arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
printf
(
"arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx
\n
"
);
"RightPx
\n
"
);
exit
(
1
);
exit
(
1
);
...
@@ -56,7 +56,7 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[])
...
@@ -56,7 +56,7 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[])
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
9
]);
const
bool
time_kernel
=
std
::
stoi
(
argv
[
9
]);
const
ck
::
index_t
N
=
std
::
stoi
(
argv
[
10
]);
const
ck
::
index_t
N
=
std
::
stoi
(
argv
[
10
]);
const
ck
::
index_t
K
=
std
::
stoi
(
argv
[
11
]);
const
ck
::
index_t
K
=
std
::
stoi
(
argv
[
11
]);
...
@@ -94,7 +94,7 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[])
...
@@ -94,7 +94,7 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
N
,
N
,
K
,
K
,
C
,
C
,
...
@@ -111,5 +111,5 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[])
...
@@ -111,5 +111,5 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[])
throw
std
::
runtime_error
(
"wrong! data_type & layout for this operator is not implemented"
);
throw
std
::
runtime_error
(
"wrong! data_type & layout for this operator is not implemented"
);
}
}
return
1
;
return
0
;
}
}
profiler/src/profile_conv_fwd_bias_relu_atomic_add.cpp
View file @
f9c478e2
...
@@ -43,7 +43,7 @@ int profile_conv_fwd_bias_relu_atomic_add(int argc, char* argv[])
...
@@ -43,7 +43,7 @@ int profile_conv_fwd_bias_relu_atomic_add(int argc, char* argv[])
printf
(
"arg6: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg6: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg7: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg7: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg8: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg8: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg9:
run
kernel
# of times (>1
)
\n
"
);
printf
(
"arg9:
time
kernel
(0=n0, 1=yes
)
\n
"
);
printf
(
"arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
printf
(
"arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx
\n
"
);
"RightPx
\n
"
);
exit
(
1
);
exit
(
1
);
...
@@ -56,7 +56,7 @@ int profile_conv_fwd_bias_relu_atomic_add(int argc, char* argv[])
...
@@ -56,7 +56,7 @@ int profile_conv_fwd_bias_relu_atomic_add(int argc, char* argv[])
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
9
]);
const
bool
time_kernel
=
std
::
stoi
(
argv
[
9
]);
const
ck
::
index_t
N
=
std
::
stoi
(
argv
[
10
]);
const
ck
::
index_t
N
=
std
::
stoi
(
argv
[
10
]);
const
ck
::
index_t
K
=
std
::
stoi
(
argv
[
11
]);
const
ck
::
index_t
K
=
std
::
stoi
(
argv
[
11
]);
...
@@ -95,7 +95,7 @@ int profile_conv_fwd_bias_relu_atomic_add(int argc, char* argv[])
...
@@ -95,7 +95,7 @@ int profile_conv_fwd_bias_relu_atomic_add(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
N
,
N
,
K
,
K
,
C
,
C
,
...
@@ -112,5 +112,5 @@ int profile_conv_fwd_bias_relu_atomic_add(int argc, char* argv[])
...
@@ -112,5 +112,5 @@ int profile_conv_fwd_bias_relu_atomic_add(int argc, char* argv[])
throw
std
::
runtime_error
(
"wrong! data_type & layout for this operator is not implemented"
);
throw
std
::
runtime_error
(
"wrong! data_type & layout for this operator is not implemented"
);
}
}
return
1
;
return
0
;
}
}
profiler/src/profile_convnd_bwd_data.cpp
View file @
f9c478e2
...
@@ -39,40 +39,40 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[],
...
@@ -39,40 +39,40 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[],
// (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right)
// (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right)
ck
::
utils
::
conv
::
ConvParams
params
;
ck
::
utils
::
conv
::
ConvParams
params
;
params
.
num_dim_spatial
=
num_dim_spatial
;
params
.
num_dim_spatial
_
=
num_dim_spatial
;
params
.
N
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
N
_
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
K
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
K
_
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
C
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
C
_
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
filter_spatial_lengths
.
resize
(
num_dim_spatial
);
params
.
filter_spatial_lengths
_
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
{
params
.
filter_spatial_lengths
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
filter_spatial_lengths
_
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
}
params
.
input_spatial_lengths
.
resize
(
num_dim_spatial
);
params
.
input_spatial_lengths
_
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
{
params
.
input_spatial_lengths
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
input_spatial_lengths
_
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
}
params
.
conv_filter_strides
.
resize
(
num_dim_spatial
);
params
.
conv_filter_strides
_
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
{
params
.
conv_filter_strides
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
conv_filter_strides
_
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
}
params
.
conv_filter_dilations
.
resize
(
num_dim_spatial
);
params
.
conv_filter_dilations
_
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
{
params
.
conv_filter_dilations
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
conv_filter_dilations
_
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
}
params
.
input_left_pads
.
resize
(
num_dim_spatial
);
params
.
input_left_pads
_
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
{
params
.
input_left_pads
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
input_left_pads
_
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
}
params
.
input_right_pads
.
resize
(
num_dim_spatial
);
params
.
input_right_pads
_
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
{
params
.
input_right_pads
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
input_right_pads
_
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
}
return
params
;
return
params
;
...
@@ -95,7 +95,7 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
...
@@ -95,7 +95,7 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
printf
(
"arg6: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg6: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg7: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg7: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg8: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg8: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg9:
run
kernel
# of times (>1
)
\n
"
);
printf
(
"arg9:
time
kernel
(0=n0, 1=yes
)
\n
"
);
printf
(
"arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
printf
(
"arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx
\n
"
);
"RightPx
\n
"
);
return
1
;
return
1
;
...
@@ -108,7 +108,7 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
...
@@ -108,7 +108,7 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
9
]);
const
bool
time_kernel
=
std
::
stoi
(
argv
[
9
]);
ck
::
utils
::
conv
::
ConvParams
params
=
parse_conv_params
(
num_dim_spatial
,
argv
,
preParams
);
ck
::
utils
::
conv
::
ConvParams
params
=
parse_conv_params
(
num_dim_spatial
,
argv
,
preParams
);
...
@@ -132,17 +132,17 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
...
@@ -132,17 +132,17 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
params
.
N
,
params
.
N
_
,
params
.
K
,
params
.
K
_
,
params
.
C
,
params
.
C
_
,
params
.
input_spatial_lengths
,
params
.
input_spatial_lengths
_
,
params
.
filter_spatial_lengths
,
params
.
filter_spatial_lengths
_
,
params
.
GetOutputSpatialLengths
(),
params
.
GetOutputSpatialLengths
(),
params
.
conv_filter_strides
,
params
.
conv_filter_strides
_
,
params
.
conv_filter_dilations
,
params
.
conv_filter_dilations
_
,
params
.
input_left_pads
,
params
.
input_left_pads
_
,
params
.
input_right_pads
);
params
.
input_right_pads
_
);
break
;
break
;
case
2
:
case
2
:
...
@@ -157,17 +157,17 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
...
@@ -157,17 +157,17 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
params
.
N
,
params
.
N
_
,
params
.
K
,
params
.
K
_
,
params
.
C
,
params
.
C
_
,
params
.
input_spatial_lengths
,
params
.
input_spatial_lengths
_
,
params
.
filter_spatial_lengths
,
params
.
filter_spatial_lengths
_
,
params
.
GetOutputSpatialLengths
(),
params
.
GetOutputSpatialLengths
(),
params
.
conv_filter_strides
,
params
.
conv_filter_strides
_
,
params
.
conv_filter_dilations
,
params
.
conv_filter_dilations
_
,
params
.
input_left_pads
,
params
.
input_left_pads
_
,
params
.
input_right_pads
);
params
.
input_right_pads
_
);
break
;
break
;
case
3
:
case
3
:
...
@@ -182,17 +182,17 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
...
@@ -182,17 +182,17 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
params
.
N
,
params
.
N
_
,
params
.
K
,
params
.
K
_
,
params
.
C
,
params
.
C
_
,
params
.
input_spatial_lengths
,
params
.
input_spatial_lengths
_
,
params
.
filter_spatial_lengths
,
params
.
filter_spatial_lengths
_
,
params
.
GetOutputSpatialLengths
(),
params
.
GetOutputSpatialLengths
(),
params
.
conv_filter_strides
,
params
.
conv_filter_strides
_
,
params
.
conv_filter_dilations
,
params
.
conv_filter_dilations
_
,
params
.
input_left_pads
,
params
.
input_left_pads
_
,
params
.
input_right_pads
);
params
.
input_right_pads
_
);
break
;
break
;
default:
break
;
default:
break
;
...
...
profiler/src/profile_convnd_fwd.cpp
View file @
f9c478e2
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
#include <vector>
#include <vector>
#include <half.hpp>
#include <half.hpp>
#include "conv_
fwd_
util.hpp"
#include "conv_util.hpp"
#include "element_wise_operation.hpp"
#include "element_wise_operation.hpp"
#include "fill.hpp"
#include "fill.hpp"
#include "profile_convnd_fwd.hpp"
#include "profile_convnd_fwd.hpp"
...
@@ -119,7 +119,7 @@ template <int NDim,
...
@@ -119,7 +119,7 @@ template <int NDim,
void
profile_convnd_instances_impl
(
const
ck
::
utils
::
conv
::
ConvParams
&
params
,
void
profile_convnd_instances_impl
(
const
ck
::
utils
::
conv
::
ConvParams
&
params
,
bool
do_verification
,
bool
do_verification
,
bool
do_log
,
bool
do_log
,
int
nrepeat
,
bool
time_kernel
,
int
init_method
,
int
init_method
,
ConvLayouts
)
ConvLayouts
)
{
{
...
@@ -185,7 +185,7 @@ void profile_convnd_instances_impl(const ck::utils::conv::ConvParams& params,
...
@@ -185,7 +185,7 @@ void profile_convnd_instances_impl(const ck::utils::conv::ConvParams& params,
reference_conv_fwd_fun
);
reference_conv_fwd_fun
);
auto
best_conf
=
run_engine
.
Profile
(
auto
best_conf
=
run_engine
.
Profile
(
conv
::
ConvolutionFwdInstances
<
InDataType
,
WeiDataType
,
OutDataType
>::
template
Get
<
NDim
>(),
conv
::
ConvolutionFwdInstances
<
InDataType
,
WeiDataType
,
OutDataType
>::
template
Get
<
NDim
>(),
nrepeat
,
time_kernel
,
do_verification
,
do_verification
,
do_log
);
do_log
);
...
@@ -201,7 +201,7 @@ void profile_convnd_instances(ConvDataType data_type,
...
@@ -201,7 +201,7 @@ void profile_convnd_instances(ConvDataType data_type,
const
ck
::
utils
::
conv
::
ConvParams
&
params
,
const
ck
::
utils
::
conv
::
ConvParams
&
params
,
bool
do_verification
,
bool
do_verification
,
bool
do_log
,
bool
do_log
,
int
nrepeat
,
bool
time_kernel
,
int
init_method
)
int
init_method
)
{
{
switch
(
data_layout
)
switch
(
data_layout
)
...
@@ -214,7 +214,7 @@ void profile_convnd_instances(ConvDataType data_type,
...
@@ -214,7 +214,7 @@ void profile_convnd_instances(ConvDataType data_type,
params
,
params
,
do_verification
,
do_verification
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
init_method
,
init_method
,
ConvolutionLayouts
<
NDim
,
ConvDataLayout
::
NHWC
>
{});
ConvolutionLayouts
<
NDim
,
ConvDataLayout
::
NHWC
>
{});
break
;
break
;
...
@@ -223,7 +223,7 @@ void profile_convnd_instances(ConvDataType data_type,
...
@@ -223,7 +223,7 @@ void profile_convnd_instances(ConvDataType data_type,
params
,
params
,
do_verification
,
do_verification
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
init_method
,
init_method
,
ConvolutionLayouts
<
NDim
,
ConvDataLayout
::
NHWC
>
{});
ConvolutionLayouts
<
NDim
,
ConvDataLayout
::
NHWC
>
{});
break
;
break
;
...
@@ -232,7 +232,7 @@ void profile_convnd_instances(ConvDataType data_type,
...
@@ -232,7 +232,7 @@ void profile_convnd_instances(ConvDataType data_type,
params
,
params
,
do_verification
,
do_verification
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
init_method
,
init_method
,
ConvolutionLayouts
<
NDim
,
ConvDataLayout
::
NHWC
>
{});
ConvolutionLayouts
<
NDim
,
ConvDataLayout
::
NHWC
>
{});
break
;
break
;
...
@@ -241,7 +241,7 @@ void profile_convnd_instances(ConvDataType data_type,
...
@@ -241,7 +241,7 @@ void profile_convnd_instances(ConvDataType data_type,
params
,
params
,
do_verification
,
do_verification
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
init_method
,
init_method
,
ConvolutionLayouts
<
NDim
,
ConvDataLayout
::
NHWC
>
{});
ConvolutionLayouts
<
NDim
,
ConvDataLayout
::
NHWC
>
{});
break
;
break
;
...
@@ -256,7 +256,7 @@ void profile_convnd_instances(ConvDataType data_type,
...
@@ -256,7 +256,7 @@ void profile_convnd_instances(ConvDataType data_type,
params
,
params
,
do_verification
,
do_verification
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
init_method
,
init_method
,
ConvolutionLayouts
<
NDim
,
ConvDataLayout
::
NCHW
>
{});
ConvolutionLayouts
<
NDim
,
ConvDataLayout
::
NCHW
>
{});
break
;
break
;
...
@@ -265,7 +265,7 @@ void profile_convnd_instances(ConvDataType data_type,
...
@@ -265,7 +265,7 @@ void profile_convnd_instances(ConvDataType data_type,
params
,
params
,
do_verification
,
do_verification
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
init_method
,
init_method
,
ConvolutionLayouts
<
NDim
,
ConvDataLayout
::
NCHW
>
{});
ConvolutionLayouts
<
NDim
,
ConvDataLayout
::
NCHW
>
{});
break
;
break
;
...
@@ -274,7 +274,7 @@ void profile_convnd_instances(ConvDataType data_type,
...
@@ -274,7 +274,7 @@ void profile_convnd_instances(ConvDataType data_type,
params
,
params
,
do_verification
,
do_verification
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
init_method
,
init_method
,
ConvolutionLayouts
<
NDim
,
ConvDataLayout
::
NCHW
>
{});
ConvolutionLayouts
<
NDim
,
ConvDataLayout
::
NCHW
>
{});
break
;
break
;
...
@@ -283,7 +283,7 @@ void profile_convnd_instances(ConvDataType data_type,
...
@@ -283,7 +283,7 @@ void profile_convnd_instances(ConvDataType data_type,
params
,
params
,
do_verification
,
do_verification
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
init_method
,
init_method
,
ConvolutionLayouts
<
NDim
,
ConvDataLayout
::
NCHW
>
{});
ConvolutionLayouts
<
NDim
,
ConvDataLayout
::
NCHW
>
{});
break
;
break
;
...
@@ -304,7 +304,7 @@ int ck::profiler::profile_convnd_fwd(int argc, char* argv[])
...
@@ -304,7 +304,7 @@ int ck::profiler::profile_convnd_fwd(int argc, char* argv[])
bool
do_verification
{
true
};
bool
do_verification
{
true
};
int
init_method
{
2
};
int
init_method
{
2
};
bool
do_log
{
false
};
bool
do_log
{
false
};
int
nrepeat
{
100
};
bool
time_kernel
{
false
};
int
num_dim_spatial
{
2
};
int
num_dim_spatial
{
2
};
ConvParams
params
;
ConvParams
params
;
...
@@ -318,7 +318,7 @@ int ck::profiler::profile_convnd_fwd(int argc, char* argv[])
...
@@ -318,7 +318,7 @@ int ck::profiler::profile_convnd_fwd(int argc, char* argv[])
do_verification
=
std
::
stoi
(
argv
[
4
]);
do_verification
=
std
::
stoi
(
argv
[
4
]);
init_method
=
std
::
stoi
(
argv
[
5
]);
init_method
=
std
::
stoi
(
argv
[
5
]);
do_log
=
std
::
stoi
(
argv
[
6
]);
do_log
=
std
::
stoi
(
argv
[
6
]);
nrepeat
=
std
::
stoi
(
argv
[
7
]);
time_kernel
=
std
::
stoi
(
argv
[
7
]);
num_dim_spatial
=
std
::
stoi
(
argv
[
8
]);
num_dim_spatial
=
std
::
stoi
(
argv
[
8
]);
}
}
if
(
argc
>=
10
)
if
(
argc
>=
10
)
...
@@ -332,20 +332,20 @@ int ck::profiler::profile_convnd_fwd(int argc, char* argv[])
...
@@ -332,20 +332,20 @@ int ck::profiler::profile_convnd_fwd(int argc, char* argv[])
{
{
case
1
:
case
1
:
profile_convnd_instances
<
1
>
(
profile_convnd_instances
<
1
>
(
data_type
,
data_layout
,
params
,
do_verification
,
do_log
,
nrepeat
,
init_method
);
data_type
,
data_layout
,
params
,
do_verification
,
do_log
,
time_kernel
,
init_method
);
break
;
break
;
case
2
:
case
2
:
profile_convnd_instances
<
2
>
(
profile_convnd_instances
<
2
>
(
data_type
,
data_layout
,
params
,
do_verification
,
do_log
,
nrepeat
,
init_method
);
data_type
,
data_layout
,
params
,
do_verification
,
do_log
,
time_kernel
,
init_method
);
break
;
break
;
case
3
:
case
3
:
profile_convnd_instances
<
3
>
(
profile_convnd_instances
<
3
>
(
data_type
,
data_layout
,
params
,
do_verification
,
do_log
,
nrepeat
,
init_method
);
data_type
,
data_layout
,
params
,
do_verification
,
do_log
,
time_kernel
,
init_method
);
break
;
break
;
default:
default:
throw
std
::
runtime_error
(
"profile_conv_fwd: unsupported num_dim_spatial value: "
+
throw
std
::
runtime_error
(
"profile_conv_fwd: unsupported num_dim_spatial value: "
+
std
::
to_string
(
num_dim_spatial
));
std
::
to_string
(
num_dim_spatial
));
}
}
return
1
;
return
0
;
}
}
profiler/src/profile_gemm.cpp
View file @
f9c478e2
...
@@ -38,8 +38,8 @@ int profile_gemm(int argc, char* argv[])
...
@@ -38,8 +38,8 @@ int profile_gemm(int argc, char* argv[])
printf
(
" 3: A[k, m] * B[n, k] = C[m, n])
\n
"
);
printf
(
" 3: A[k, m] * B[n, k] = C[m, n])
\n
"
);
printf
(
"arg4: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg4: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg5: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg5: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg
8
: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg
6
: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg7:
run
kernel
# of times (>1
)
\n
"
);
printf
(
"arg7:
time
kernel
(0=n0, 1=yes
)
\n
"
);
printf
(
"arg8 to 13: M, N, K, StrideA, StrideB, StrideC
\n
"
);
printf
(
"arg8 to 13: M, N, K, StrideA, StrideB, StrideC
\n
"
);
printf
(
"arg14: split k into mulitiple batch
\n
"
);
printf
(
"arg14: split k into mulitiple batch
\n
"
);
exit
(
1
);
exit
(
1
);
...
@@ -50,7 +50,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -50,7 +50,7 @@ int profile_gemm(int argc, char* argv[])
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
7
]);
const
bool
time_kernel
=
std
::
stoi
(
argv
[
7
]);
const
int
M
=
std
::
stoi
(
argv
[
8
]);
const
int
M
=
std
::
stoi
(
argv
[
8
]);
const
int
N
=
std
::
stoi
(
argv
[
9
]);
const
int
N
=
std
::
stoi
(
argv
[
9
]);
...
@@ -68,13 +68,14 @@ int profile_gemm(int argc, char* argv[])
...
@@ -68,13 +68,14 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
ck
::
half_t
,
ck
::
profiler
::
profile_gemm_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -88,13 +89,14 @@ int profile_gemm(int argc, char* argv[])
...
@@ -88,13 +89,14 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
ck
::
half_t
,
ck
::
profiler
::
profile_gemm_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -108,13 +110,14 @@ int profile_gemm(int argc, char* argv[])
...
@@ -108,13 +110,14 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
ck
::
half_t
,
ck
::
profiler
::
profile_gemm_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -128,13 +131,14 @@ int profile_gemm(int argc, char* argv[])
...
@@ -128,13 +131,14 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
ck
::
half_t
,
ck
::
profiler
::
profile_gemm_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -146,6 +150,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -146,6 +150,7 @@ int profile_gemm(int argc, char* argv[])
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
ck
::
profiler
::
profile_gemm_impl
<
float
,
ck
::
profiler
::
profile_gemm_impl
<
float
,
float
,
float
,
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
...
@@ -154,7 +159,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -154,7 +159,7 @@ int profile_gemm(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -166,6 +171,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -166,6 +171,7 @@ int profile_gemm(int argc, char* argv[])
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
{
ck
::
profiler
::
profile_gemm_impl
<
float
,
ck
::
profiler
::
profile_gemm_impl
<
float
,
float
,
float
,
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
...
@@ -174,7 +180,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -174,7 +180,7 @@ int profile_gemm(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -186,6 +192,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -186,6 +192,7 @@ int profile_gemm(int argc, char* argv[])
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
{
ck
::
profiler
::
profile_gemm_impl
<
float
,
ck
::
profiler
::
profile_gemm_impl
<
float
,
float
,
float
,
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
...
@@ -194,7 +201,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -194,7 +201,7 @@ int profile_gemm(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -206,6 +213,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -206,6 +213,7 @@ int profile_gemm(int argc, char* argv[])
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
{
ck
::
profiler
::
profile_gemm_impl
<
float
,
ck
::
profiler
::
profile_gemm_impl
<
float
,
float
,
float
,
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
...
@@ -214,7 +222,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -214,7 +222,7 @@ int profile_gemm(int argc, char* argv[])
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -228,13 +236,14 @@ int profile_gemm(int argc, char* argv[])
...
@@ -228,13 +236,14 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
int8_t
,
int8_t
,
int32_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -248,13 +257,14 @@ int profile_gemm(int argc, char* argv[])
...
@@ -248,13 +257,14 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
int8_t
,
int8_t
,
int32_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -268,13 +278,14 @@ int profile_gemm(int argc, char* argv[])
...
@@ -268,13 +278,14 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
int8_t
,
int8_t
,
int32_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -288,13 +299,14 @@ int profile_gemm(int argc, char* argv[])
...
@@ -288,13 +299,14 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
int8_t
,
int8_t
,
int32_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -308,13 +320,14 @@ int profile_gemm(int argc, char* argv[])
...
@@ -308,13 +320,14 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -328,13 +341,14 @@ int profile_gemm(int argc, char* argv[])
...
@@ -328,13 +341,14 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -348,13 +362,14 @@ int profile_gemm(int argc, char* argv[])
...
@@ -348,13 +362,14 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -368,13 +383,14 @@ int profile_gemm(int argc, char* argv[])
...
@@ -368,13 +383,14 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
time_kernel
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -388,5 +404,5 @@ int profile_gemm(int argc, char* argv[])
...
@@ -388,5 +404,5 @@ int profile_gemm(int argc, char* argv[])
throw
std
::
runtime_error
(
"wrong! this GEMM data_type & layout is not implemented"
);
throw
std
::
runtime_error
(
"wrong! this GEMM data_type & layout is not implemented"
);
}
}
return
1
;
return
0
;
}
}
Prev
1
…
11
12
13
14
15
16
17
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment