Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
dd6a8de4
Commit
dd6a8de4
authored
Apr 06, 2022
by
Jehandad Khan
Browse files
Merge branch 'develop' into jd/dev_pkg
parents
0aa899aa
abf4bdb9
Changes
470
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1918 additions
and
246 deletions
+1918
-246
profiler/include/profile_gemm_bias_2d_impl.hpp
profiler/include/profile_gemm_bias_2d_impl.hpp
+5
-3
profiler/include/profile_gemm_bias_relu_add_impl.hpp
profiler/include/profile_gemm_bias_relu_add_impl.hpp
+3
-1
profiler/include/profile_gemm_bias_relu_impl.hpp
profiler/include/profile_gemm_bias_relu_impl.hpp
+4
-2
profiler/include/profile_gemm_impl.hpp
profiler/include/profile_gemm_impl.hpp
+82
-10
profiler/include/profile_gemm_reduce_impl.hpp
profiler/include/profile_gemm_reduce_impl.hpp
+335
-0
profiler/include/profile_grouped_gemm_impl.hpp
profiler/include/profile_grouped_gemm_impl.hpp
+316
-0
profiler/include/profile_reduce_impl.hpp
profiler/include/profile_reduce_impl.hpp
+225
-167
profiler/src/profile_batched_gemm.cpp
profiler/src/profile_batched_gemm.cpp
+245
-7
profiler/src/profile_batched_gemm_reduce.cpp
profiler/src/profile_batched_gemm_reduce.cpp
+153
-0
profiler/src/profile_conv_bwd_data.cpp
profiler/src/profile_conv_bwd_data.cpp
+12
-8
profiler/src/profile_conv_bwd_weight.cpp
profiler/src/profile_conv_bwd_weight.cpp
+146
-0
profiler/src/profile_conv_fwd.cpp
profiler/src/profile_conv_fwd.cpp
+8
-8
profiler/src/profile_conv_fwd_bias_relu.cpp
profiler/src/profile_conv_fwd_bias_relu.cpp
+8
-8
profiler/src/profile_conv_fwd_bias_relu_add.cpp
profiler/src/profile_conv_fwd_bias_relu_add.cpp
+8
-8
profiler/src/profile_conv_fwd_bias_relu_atomic_add.cpp
profiler/src/profile_conv_fwd_bias_relu_atomic_add.cpp
+8
-8
profiler/src/profile_convnd_bwd_data.cpp
profiler/src/profile_convnd_bwd_data.cpp
+224
-0
profiler/src/profile_gemm.cpp
profiler/src/profile_gemm.cpp
+124
-4
profiler/src/profile_gemm_bias_2d.cpp
profiler/src/profile_gemm_bias_2d.cpp
+4
-4
profiler/src/profile_gemm_bias_relu.cpp
profiler/src/profile_gemm_bias_relu.cpp
+4
-4
profiler/src/profile_gemm_bias_relu_add.cpp
profiler/src/profile_gemm_bias_relu_add.cpp
+4
-4
No files found.
profiler/include/profile_gemm_bias_2d_impl.hpp
View file @
dd6a8de4
#pragma once
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
...
...
@@ -7,7 +9,7 @@
#include "tensor_layout.hpp"
#include "device_tensor.hpp"
#include "element_wise_operation.hpp"
#include "device_gemm.hpp"
#include "device_gemm
_bias
.hpp"
#include "reference_gemm_bias_2d.hpp"
namespace
ck
{
...
...
@@ -98,7 +100,7 @@ void profile_gemm_bias_2d_impl(int do_verification,
std
::
cout
<<
"c0_m_n: "
<<
c0_m_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n: "
<<
c_m_n_host_result
.
mDesc
<<
std
::
endl
;
std
::
size_t
num_thread
=
std
::
thread
::
hardware_concurrency
()
;
std
::
size_t
num_thread
=
1
;
switch
(
init_method
)
{
case
0
:
break
;
...
...
@@ -283,7 +285,7 @@ void profile_gemm_bias_2d_impl(int do_verification,
{
c_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
check_err
or
(
c_m_n_
host
_result
,
c_m_n_
device
_result
);
ck
::
utils
::
check_err
(
c_m_n_
device
_result
.
mData
,
c_m_n_
host
_result
.
mData
);
if
(
do_log
)
{
...
...
profiler/include/profile_gemm_bias_relu_add_impl.hpp
View file @
dd6a8de4
#pragma once
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
...
...
@@ -257,7 +259,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
{
c_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
check_err
or
(
c_m_n_
host
_result
,
c_m_n_
device
_result
);
ck
::
utils
::
check_err
(
c_m_n_
device
_result
.
mData
,
c_m_n_
host
_result
.
mData
);
if
(
do_log
)
{
...
...
profiler/include/profile_gemm_bias_relu_impl.hpp
View file @
dd6a8de4
#pragma once
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
...
...
@@ -83,7 +85,7 @@ void profile_gemm_bias_relu_impl(int do_verification,
std
::
cout
<<
"c_m_n: "
<<
c_m_n_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c0_n: "
<<
c0_n
.
mDesc
<<
std
::
endl
;
std
::
size_t
num_thread
=
std
::
thread
::
hardware_concurrency
()
;
std
::
size_t
num_thread
=
1
;
switch
(
init_method
)
{
case
0
:
break
;
...
...
@@ -236,7 +238,7 @@ void profile_gemm_bias_relu_impl(int do_verification,
{
c_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
check_err
or
(
c_m_n_
host
_result
,
c_m_n_
device
_result
);
ck
::
utils
::
check_err
(
c_m_n_
device
_result
.
mData
,
c_m_n_
host
_result
.
mData
);
if
(
do_log
)
{
...
...
profiler/include/profile_gemm_impl.hpp
View file @
dd6a8de4
#pragma once
#include <iomanip>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
...
...
@@ -26,16 +28,28 @@ void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNo
void
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
...
...
@@ -45,6 +59,11 @@ void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNo
void
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
...
...
@@ -103,7 +122,7 @@ void profile_gemm_impl(int do_verification,
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n: "
<<
c_m_n_device_result
.
mDesc
<<
std
::
endl
;
std
::
size_t
num_thread
=
std
::
thread
::
hardware_concurrency
()
;
std
::
size_t
num_thread
=
1
;
switch
(
init_method
)
{
case
0
:
break
;
...
...
@@ -127,11 +146,6 @@ void profile_gemm_impl(int do_verification,
const
auto
b_element_op
=
BElementOp
{};
const
auto
c_element_op
=
CElementOp
{};
// if(do_verification)
// {
// }
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpace
());
...
...
@@ -159,6 +173,9 @@ void profile_gemm_impl(int do_verification,
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances
(
gemm_ptrs
);
}
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
...
...
@@ -174,6 +191,9 @@ void profile_gemm_impl(int do_verification,
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances
(
gemm_ptrs
);
}
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
...
...
@@ -189,6 +209,9 @@ void profile_gemm_impl(int do_verification,
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances
(
gemm_ptrs
);
}
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
...
...
@@ -204,6 +227,9 @@ void profile_gemm_impl(int do_verification,
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances
(
gemm_ptrs
);
}
}
}
...
...
@@ -291,23 +317,65 @@ void profile_gemm_impl(int do_verification,
is_same
<
CDataType
,
ck
::
bhalf_t
>::
value
)
{
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances
(
gemm_ptrs
);
}
}
else
if
constexpr
(
is_same
<
ADataType
,
int8_t
>::
value
&&
is_same
<
BDataType
,
int8_t
>::
value
&&
is_same
<
CDataType
,
int8_t
>::
value
)
{
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances
(
gemm_ptrs
);
}
}
if
(
gemm_ptrs
.
size
()
<=
0
)
...
...
@@ -342,6 +410,10 @@ void profile_gemm_impl(int do_verification,
if
(
gemm_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
// re-init C to zero before profiling next kernel
c_m_n_device_result
.
GenerateTensorValue
(
GeneratorTensor_0
<
CDataType
>
{},
num_thread
);
c_device_buf
.
ToDevice
(
c_m_n_device_result
.
mData
.
data
());
std
::
string
gemm_name
=
gemm_ptr
->
GetTypeString
();
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
nrepeat
);
...
...
@@ -400,7 +472,7 @@ void profile_gemm_impl(int do_verification,
ref_invoker
.
Run
(
ref_argument
);
check_err
or
(
c_m_n_
host
_result
,
c_m_n_
device_f32
_result
);
ck
::
utils
::
check_err
(
c_m_n_
device_f32
_result
.
mData
,
c_m_n_
host
_result
.
mData
);
if
(
do_log
)
{
...
...
@@ -429,7 +501,7 @@ void profile_gemm_impl(int do_verification,
a_m_k
,
b_k_n
,
c_m_n_host_result
,
a_element_op
,
b_element_op
,
c_element_op
);
ref_invoker
.
Run
(
ref_argument
);
check_err
or
(
c_m_n_
host
_result
,
c_m_n_
device
_result
);
ck
::
utils
::
check_err
(
c_m_n_
device
_result
.
mData
,
c_m_n_
host
_result
.
mData
);
if
(
do_log
)
{
...
...
profiler/include/profile_gemm_reduce_impl.hpp
0 → 100644
View file @
dd6a8de4
#pragma once
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_conv.hpp"
#include "tensor_layout.hpp"
#include "device_tensor.hpp"
#include "element_wise_operation.hpp"
#include "element_wise_reduce_operation.hpp"
#include "device_gemm_reduce.hpp"
#include "reference_gemm.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
DeviceGemmReduceNoOpPtr
=
ck
::
tensor_operation
::
device
::
DeviceGemmReducePtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
ReduceSum
,
ck
::
tensor_operation
::
element_wise
::
ReduceSquareSum
>
;
void
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmReduceNoOpPtr
>&
);
void
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmReduceNoOpPtr
>&
);
void
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmReduceNoOpPtr
>&
);
void
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmReduceNoOpPtr
>&
);
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
namespace
ck
{
namespace
profiler
{
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
DDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
bool
profile_gemm_reduce_impl
(
int
do_verification
,
int
init_method
,
bool
do_log
,
int
nrepeat
,
int
M
,
int
N
,
int
K
,
int
StrideA
,
int
StrideB
,
int
StrideC
)
{
bool
pass
=
true
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
is_same
<
decltype
(
layout
),
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
stride
,
1
}));
}
else
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
1
,
stride
}));
}
};
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
CDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
DDataType
>
d0_m_host_result
(
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
M
)})));
Tensor
<
DDataType
>
d1_m_host_result
(
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
M
)})));
Tensor
<
CDataType
>
c_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
DDataType
>
d0_m_device_result
(
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
M
)})));
Tensor
<
DDataType
>
d1_m_device_result
(
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
M
)})));
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n: "
<<
c_m_n_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d0_m: "
<<
d0_m_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d1_m: "
<<
d1_m_host_result
.
mDesc
<<
std
::
endl
;
std
::
size_t
num_thread
=
1
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
std
::
srand
(
0
);
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
},
num_thread
);
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
},
num_thread
);
break
;
default:
std
::
srand
(
0
);
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
},
num_thread
);
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
},
num_thread
);
}
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
D0ReduceOp
=
ck
::
tensor_operation
::
element_wise
::
ReduceSum
;
using
D1ReduceOp
=
ck
::
tensor_operation
::
element_wise
::
ReduceSquareSum
;
const
auto
a_element_op
=
AElementOp
{};
const
auto
b_element_op
=
BElementOp
{};
const
auto
c_element_op
=
CElementOp
{};
const
auto
d0_reduce_op
=
D0ReduceOp
{};
const
auto
d1_reduce_op
=
D1ReduceOp
{};
if
(
do_verification
)
{
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_m_k
,
b_k_n
,
c_m_n_host_result
,
a_element_op
,
b_element_op
,
c_element_op
);
ref_invoker
.
Run
(
ref_argument
);
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
float
d0_acc
=
d0_reduce_op
.
GetReduceZeroValue
();
float
d1_acc
=
d1_reduce_op
.
GetReduceZeroValue
();
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
d0_reduce_op
.
Reduce
(
d0_acc
,
c_m_n_host_result
(
m
,
n
));
d1_reduce_op
.
Reduce
(
d1_acc
,
c_m_n_host_result
(
m
,
n
));
}
d0_m_host_result
(
m
)
=
d0_acc
;
d1_m_host_result
(
m
)
=
d1_acc
;
}
}
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
d0_device_buf
(
sizeof
(
DDataType
)
*
d0_m_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
d1_device_buf
(
sizeof
(
DDataType
)
*
d1_m_device_result
.
mDesc
.
GetElementSpace
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
// add device GEMM instances
std
::
vector
<
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
DeviceGemmReduceNoOpPtr
>
gemm_ptrs
;
if
constexpr
(
is_same
<
ADataType
,
half_t
>::
value
&&
is_same
<
BDataType
,
half_t
>::
value
&&
is_same
<
CDataType
,
half_t
>::
value
)
{
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances
(
gemm_ptrs
);
}
}
if
(
gemm_ptrs
.
size
()
<=
0
)
{
throw
std
::
runtime_error
(
"wrong! no device GEMM instance found"
);
}
std
::
string
best_gemm_name
;
float
best_ave_time
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
// profile device GEMM instances
for
(
auto
&
gemm_ptr
:
gemm_ptrs
)
{
auto
argument_ptr
=
gemm_ptr
->
MakeArgumentPointer
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DDataType
*>
(
d0_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DDataType
*>
(
d1_device_buf
.
GetDeviceBuffer
()),
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
a_element_op
,
b_element_op
,
c_element_op
,
d0_reduce_op
,
d1_reduce_op
);
auto
invoker_ptr
=
gemm_ptr
->
MakeInvokerPointer
();
if
(
gemm_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
// warm up
invoker_ptr
->
Run
(
argument_ptr
.
get
());
// timing
float
total_time
=
0
;
for
(
int
i
=
0
;
i
<
nrepeat
;
++
i
)
{
// init DO, D1 to 0
d0_device_buf
.
SetZero
();
d1_device_buf
.
SetZero
();
KernelTimer
timer
;
timer
.
Start
();
invoker_ptr
->
Run
(
argument_ptr
.
get
());
timer
.
End
();
total_time
+=
timer
.
GetElapsedTime
();
}
float
ave_time
=
total_time
/
nrepeat
;
std
::
string
gemm_name
=
gemm_ptr
->
GetTypeString
();
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
M
+
sizeof
(
CDataType
)
*
M
*
N
+
sizeof
(
CDataType
)
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm_name
<<
std
::
endl
;
if
(
tflops
>
best_tflops
)
{
best_gemm_name
=
gemm_name
;
best_tflops
=
tflops
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
if
(
do_verification
)
{
c_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
d0_device_buf
.
FromDevice
(
d0_m_device_result
.
mData
.
data
());
d1_device_buf
.
FromDevice
(
d1_m_device_result
.
mData
.
data
());
float
c_error
=
check_error
(
c_m_n_host_result
,
c_m_n_device_result
);
float
d0_error
=
check_error
(
d0_m_host_result
,
d0_m_device_result
);
float
d1_error
=
check_error
(
d1_m_host_result
,
d1_m_device_result
);
pass
=
pass
&&
(
c_error
<
1E-6
);
pass
=
pass
&&
(
d0_error
<
1E-6
);
pass
=
pass
&&
(
d1_error
<
1E-6
);
if
(
do_log
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"a : "
,
a_m_k
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"b: "
,
b_k_n
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_host: "
,
c_m_n_host_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_device: "
,
c_m_n_device_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"d0_host: "
,
d0_m_host_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"d0_device: "
,
d0_m_device_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"d1_host: "
,
d1_m_host_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"d1_device: "
,
d1_m_device_result
.
mData
,
","
)
<<
std
::
endl
;
}
}
}
else
{
std
::
cout
<<
"does not support this GEMM problem"
<<
std
::
endl
;
}
}
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_gemm_name
<<
std
::
endl
;
return
pass
;
}
}
// namespace profiler
}
// namespace ck
profiler/include/profile_grouped_gemm_impl.hpp
0 → 100644
View file @
dd6a8de4
#pragma once
#include <iomanip>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_conv.hpp"
#include "tensor_layout.hpp"
#include "device_tensor.hpp"
#include "element_wise_operation.hpp"
#include "device_gemm.hpp"
#include "reference_gemm.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_grouped_gemm_instance
{
using
DeviceGroupedGemmNoOpPtr
=
ck
::
tensor_operation
::
device
::
DeviceGroupedGemmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
void
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGroupedGemmNoOpPtr
>&
);
void
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGroupedGemmNoOpPtr
>&
);
void
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGroupedGemmNoOpPtr
>&
);
void
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGroupedGemmNoOpPtr
>&
);
}
// namespace device_grouped_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
namespace
ck
{
namespace
profiler
{
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
void
profile_grouped_gemm_impl
(
int
do_verification
,
int
init_method
,
bool
do_log
,
int
nrepeat
,
std
::
vector
<
int
>
Ms
,
std
::
vector
<
int
>
Ns
,
std
::
vector
<
int
>
Ks
,
std
::
vector
<
int
>
StrideAs
,
std
::
vector
<
int
>
StrideBs
,
std
::
vector
<
int
>
StrideCs
)
{
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
is_same
<
decltype
(
layout
),
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
stride
,
1
}));
}
else
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
1
,
stride
}));
}
};
int
group_count
=
Ms
.
size
();
if
(
!
(
group_count
==
Ns
.
size
()
&&
group_count
==
Ks
.
size
()
&&
group_count
==
StrideAs
.
size
()
&&
group_count
==
StrideBs
.
size
()
&&
group_count
==
StrideCs
.
size
()))
{
throw
std
::
runtime_error
(
"wrong! inconsistent M/N/Ks, StrideA/B/Cs size
\n
"
);
}
std
::
vector
<
Tensor
<
ADataType
>>
a_m_k
;
std
::
vector
<
Tensor
<
BDataType
>>
b_k_n
;
std
::
vector
<
Tensor
<
CDataType
>>
c_m_n_device_results
;
for
(
int
i
=
0
;
i
<
Ms
.
size
();
i
++
)
{
a_m_k
.
push_back
(
Tensor
<
ADataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ks
[
i
],
StrideAs
[
i
],
ALayout
{})));
b_k_n
.
push_back
(
Tensor
<
BDataType
>
(
f_host_tensor_descriptor
(
Ks
[
i
],
Ns
[
i
],
StrideBs
[
i
],
BLayout
{})));
c_m_n_device_results
.
push_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{})));
std
::
cout
<<
"group: "
<<
i
<<
" a_m_k["
<<
i
<<
"]:"
<<
a_m_k
[
i
].
mDesc
<<
", b_k_n["
<<
i
<<
"]:"
<<
b_k_n
[
i
].
mDesc
<<
", c_m_n_device_results["
<<
i
<<
"]:"
<<
c_m_n_device_results
[
i
].
mDesc
<<
std
::
endl
;
std
::
size_t
num_thread
=
1
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
a_m_k
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
},
num_thread
);
b_k_n
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
},
num_thread
);
break
;
default:
a_m_k
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
},
num_thread
);
b_k_n
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
},
num_thread
);
}
c_m_n_device_results
[
i
].
GenerateTensorValue
(
GeneratorTensor_0
<
CDataType
>
{},
num_thread
);
}
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
const
auto
a_element_op
=
AElementOp
{};
const
auto
b_element_op
=
BElementOp
{};
const
auto
c_element_op
=
CElementOp
{};
// if(do_verification)
// {
// }
using
DeviceMemPtr
=
std
::
unique_ptr
<
DeviceMem
>
;
std
::
vector
<
DeviceMemPtr
>
a_device_buf
,
b_device_buf
,
c_device_buf
;
a_device_buf
.
reserve
(
group_count
);
b_device_buf
.
reserve
(
group_count
);
c_device_buf
.
reserve
(
group_count
);
std
::
vector
<
const
void
*>
p_a
,
p_b
;
std
::
vector
<
void
*>
p_c
;
p_a
.
reserve
(
group_count
);
p_b
.
reserve
(
group_count
);
p_c
.
reserve
(
group_count
);
std
::
vector
<
ck
::
tensor_operation
::
device
::
GemmShape
>
gemm_shapes
;
gemm_shapes
.
reserve
(
group_count
);
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
a_device_buf
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ADataType
)
*
a_m_k
[
i
].
mDesc
.
GetElementSpace
()));
b_device_buf
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
BDataType
)
*
b_k_n
[
i
].
mDesc
.
GetElementSpace
()));
c_device_buf
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
CDataType
)
*
c_m_n_device_results
[
i
].
mDesc
.
GetElementSpace
()));
a_device_buf
[
i
]
->
ToDevice
(
a_m_k
[
i
].
mData
.
data
());
b_device_buf
[
i
]
->
ToDevice
(
b_k_n
[
i
].
mData
.
data
());
c_device_buf
[
i
]
->
ToDevice
(
c_m_n_device_results
[
i
].
mData
.
data
());
gemm_shapes
.
push_back
({
Ms
[
i
],
Ns
[
i
],
Ks
[
i
],
StrideAs
[
i
],
StrideBs
[
i
],
StrideCs
[
i
]});
p_a
.
push_back
(
a_device_buf
[
i
]
->
GetDeviceBuffer
());
p_b
.
push_back
(
b_device_buf
[
i
]
->
GetDeviceBuffer
());
p_c
.
push_back
(
c_device_buf
[
i
]
->
GetDeviceBuffer
());
}
// add device GEMM instances
std
::
vector
<
ck
::
tensor_operation
::
device
::
device_grouped_gemm_instance
::
DeviceGroupedGemmNoOpPtr
>
gemm_ptrs
;
if
constexpr
(
is_same
<
ADataType
,
half_t
>::
value
&&
is_same
<
BDataType
,
half_t
>::
value
&&
is_same
<
CDataType
,
half_t
>::
value
)
{
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_grouped_gemm_instance
::
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_grouped_gemm_instance
::
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_grouped_gemm_instance
::
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_grouped_gemm_instance
::
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
gemm_ptrs
);
}
}
if
(
gemm_ptrs
.
size
()
<=
0
)
{
throw
std
::
runtime_error
(
"wrong! no device GEMM instance found"
);
}
std
::
string
best_gemm_name
;
float
best_ave_time
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
// profile device GEMM instances
for
(
auto
&
gemm_ptr
:
gemm_ptrs
)
{
auto
argument_ptr
=
gemm_ptr
->
MakeArgumentPointer
(
p_a
,
p_b
,
p_c
,
gemm_shapes
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
auto
invoker_ptr
=
gemm_ptr
->
MakeInvokerPointer
();
if
(
gemm_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
std
::
string
gemm_name
=
gemm_ptr
->
GetTypeString
();
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
nrepeat
);
std
::
size_t
flop
=
0
,
num_btype
=
0
;
for
(
int
i
=
0
;
i
<
gemm_shapes
.
size
();
i
++
)
{
flop
+=
std
::
size_t
(
2
)
*
Ms
[
i
]
*
Ns
[
i
]
*
Ks
[
i
];
num_btype
+=
sizeof
(
ADataType
)
*
Ms
[
i
]
*
Ks
[
i
]
+
sizeof
(
BDataType
)
*
Ks
[
i
]
*
Ns
[
i
]
+
sizeof
(
CDataType
)
*
Ms
[
i
]
*
Ns
[
i
];
}
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm_name
<<
std
::
endl
;
if
(
tflops
>
best_tflops
)
{
best_gemm_name
=
gemm_name
;
best_tflops
=
tflops
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
if
(
do_verification
)
{
for
(
int
i
=
0
;
i
<
gemm_shapes
.
size
();
i
++
)
{
c_device_buf
[
i
]
->
FromDevice
(
c_m_n_device_results
[
i
].
mData
.
data
());
Tensor
<
CDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{}));
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_m_k
[
i
],
b_k_n
[
i
],
c_m_n_host_result
,
a_element_op
,
b_element_op
,
c_element_op
);
ref_invoker
.
Run
(
ref_argument
);
ck
::
utils
::
check_err
(
c_m_n_device_results
[
i
].
mData
,
c_m_n_host_result
.
mData
);
if
(
do_log
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"a : "
,
a_m_k
[
i
].
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"b: "
,
b_k_n
[
i
].
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_device: "
,
c_m_n_device_results
[
i
].
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_host : "
,
c_m_n_host_result
.
mData
,
","
)
<<
std
::
endl
;
}
}
}
}
else
{
std
::
cout
<<
"does not support this GEMM problem"
<<
std
::
endl
;
}
}
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_gemm_name
<<
std
::
endl
;
}
// namespace profiler
}
// namespace profiler
}
// namespace ck
profiler/include/profile_reduce_impl.hpp
View file @
dd6a8de4
#pragma once
#include "check_err.hpp"
#include "device_reduce.hpp"
#include "device_reduce_instance.hpp"
#include "reduction_enums.hpp"
#include "host_
generic_
reduction.hpp"
#include "host_reduction.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_reduce_instance
{
template
<
int
Rank
,
typename
ReduceDim
s
,
int
ReduceOpId
,
int
NanOpt
,
int
IndicesOpt
>
template
<
int
Rank
,
int
Num
ReduceDim
,
int
ReduceOpId
,
int
NanOpt
,
int
IndicesOpt
>
struct
ReduceDescription
{
static
constexpr
int
Rank_
=
Rank
;
static
constexpr
int
NumReduceDim_
=
NumReduceDim
;
static
constexpr
int
ReduceOpId_
=
ReduceOpId
;
static
constexpr
int
NanOpt_
=
NanOpt
;
static
constexpr
int
IndicesOpt_
=
IndicesOpt
;
using
ReduceDims_
=
ReduceDims
;
};
using
reduce_description_instances
=
std
::
tuple
<
ReduceDescription
<
4
,
Sequence
<
0
,
1
,
2
>
,
0
,
0
,
0
>
,
// for ADD
ReduceDescription
<
4
,
Sequence
<
0
>
,
0
,
0
,
0
>
,
ReduceDescription
<
2
,
Sequence
<
1
>
,
0
,
0
,
0
>
,
ReduceDescription
<
4
,
Sequence
<
0
,
1
,
2
>
,
5
,
0
,
0
>
,
// for AVG
ReduceDescription
<
4
,
Sequence
<
0
>
,
5
,
0
,
0
>
,
ReduceDescription
<
2
,
Sequence
<
1
>
,
5
,
0
,
0
>
,
ReduceDescription
<
4
,
Sequence
<
0
,
1
,
2
>
,
7
,
0
,
0
>
,
// for NORM2
ReduceDescription
<
4
,
Sequence
<
0
>
,
7
,
0
,
0
>
,
ReduceDescription
<
2
,
Sequence
<
1
>
,
7
,
0
,
0
>
,
ReduceDescription
<
4
,
Sequence
<
0
,
1
,
2
>
,
2
,
0
,
0
>
,
// for MIN
ReduceDescription
<
4
,
Sequence
<
0
>
,
2
,
0
,
0
>
,
ReduceDescription
<
2
,
Sequence
<
1
>
,
2
,
0
,
0
>
,
ReduceDescription
<
4
,
Sequence
<
0
,
1
,
2
>
,
3
,
0
,
0
>
,
// for MAX
ReduceDescription
<
4
,
Sequence
<
0
>
,
3
,
0
,
0
>
,
ReduceDescription
<
2
,
Sequence
<
1
>
,
3
,
0
,
0
>
,
ReduceDescription
<
4
,
Sequence
<
0
,
1
,
2
>
,
4
,
0
,
0
>
,
// for AMAX
ReduceDescription
<
4
,
Sequence
<
0
>
,
4
,
0
,
0
>
,
ReduceDescription
<
2
,
Sequence
<
1
>
,
4
,
0
,
0
>
,
ReduceDescription
<
4
,
Sequence
<
0
,
1
,
2
>
,
2
,
0
,
1
>
,
// for MIN
ReduceDescription
<
4
,
Sequence
<
0
>
,
2
,
0
,
1
>
,
ReduceDescription
<
2
,
Sequence
<
1
>
,
2
,
0
,
1
>
,
ReduceDescription
<
4
,
Sequence
<
0
,
1
,
2
>
,
3
,
0
,
1
>
,
// for MAX
ReduceDescription
<
4
,
Sequence
<
0
>
,
3
,
0
,
1
>
,
ReduceDescription
<
2
,
Sequence
<
1
>
,
3
,
0
,
1
>
,
ReduceDescription
<
4
,
Sequence
<
0
,
1
,
2
>
,
4
,
0
,
1
>
,
// for AMAX
ReduceDescription
<
4
,
Sequence
<
0
>
,
4
,
0
,
1
>
,
ReduceDescription
<
2
,
Sequence
<
1
>
,
4
,
0
,
1
>>
;
using
reduce_description_instances
=
std
::
tuple
<
ReduceDescription
<
4
,
3
,
0
,
0
,
0
>
,
// for ADD
ReduceDescription
<
4
,
4
,
0
,
0
,
0
>
,
ReduceDescription
<
4
,
1
,
0
,
0
,
0
>
,
ReduceDescription
<
2
,
1
,
0
,
0
,
0
>
,
ReduceDescription
<
4
,
3
,
5
,
0
,
0
>
,
// for AVG
ReduceDescription
<
4
,
4
,
5
,
0
,
0
>
,
ReduceDescription
<
4
,
1
,
5
,
0
,
0
>
,
ReduceDescription
<
2
,
1
,
5
,
0
,
0
>
,
ReduceDescription
<
4
,
3
,
7
,
0
,
0
>
,
// for NORM2
ReduceDescription
<
4
,
4
,
7
,
0
,
0
>
,
ReduceDescription
<
4
,
1
,
7
,
0
,
0
>
,
ReduceDescription
<
2
,
1
,
7
,
0
,
0
>
,
ReduceDescription
<
4
,
3
,
2
,
0
,
0
>
,
// for MIN
ReduceDescription
<
4
,
4
,
2
,
0
,
0
>
,
ReduceDescription
<
4
,
1
,
2
,
0
,
0
>
,
ReduceDescription
<
2
,
1
,
2
,
0
,
0
>
,
ReduceDescription
<
4
,
3
,
3
,
0
,
0
>
,
// for MAX
ReduceDescription
<
4
,
4
,
3
,
0
,
0
>
,
ReduceDescription
<
4
,
1
,
3
,
0
,
0
>
,
ReduceDescription
<
2
,
1
,
3
,
0
,
0
>
,
ReduceDescription
<
4
,
3
,
4
,
0
,
0
>
,
// for AMAX
ReduceDescription
<
4
,
4
,
4
,
0
,
0
>
,
ReduceDescription
<
4
,
1
,
4
,
0
,
0
>
,
ReduceDescription
<
2
,
1
,
4
,
0
,
0
>
,
ReduceDescription
<
4
,
3
,
2
,
0
,
1
>
,
// for MIN
ReduceDescription
<
4
,
4
,
2
,
0
,
1
>
,
ReduceDescription
<
4
,
1
,
2
,
0
,
1
>
,
ReduceDescription
<
2
,
1
,
2
,
0
,
1
>
,
ReduceDescription
<
4
,
3
,
3
,
0
,
1
>
,
// for MAX
ReduceDescription
<
4
,
4
,
3
,
0
,
1
>
,
ReduceDescription
<
4
,
1
,
3
,
0
,
1
>
,
ReduceDescription
<
2
,
1
,
3
,
0
,
1
>
,
ReduceDescription
<
4
,
3
,
4
,
0
,
1
>
,
// for AMAX
ReduceDescription
<
4
,
4
,
4
,
0
,
1
>
,
ReduceDescription
<
4
,
1
,
4
,
0
,
1
>
,
ReduceDescription
<
2
,
1
,
4
,
0
,
1
>>
;
template
<
typename
DescriptionType
>
bool
description_match
(
const
DescriptionType
&
description
,
int
Rank
,
const
std
::
vector
<
int
>&
R
educeDims
,
ReduceTensorOp
_t
ReduceOpId
,
NanPropagation
_t
NanOpt
,
ReduceTensorIndices
_t
IndicesOpt
)
const
std
::
vector
<
int
>&
r
educeDims
,
ReduceTensorOp
ReduceOpId
,
NanPropagation
NanOpt
,
ReduceTensorIndices
IndicesOpt
)
{
if
(
description
.
Rank_
!=
Rank
||
description
.
ReduceOpId_
!=
static_cast
<
int
>
(
ReduceOpId
)
||
description
.
NanOpt_
!=
static_cast
<
int
>
(
NanOpt
)
||
description
.
IndicesOpt_
!=
static_cast
<
int
>
(
IndicesOpt
))
return
(
false
);
if
(
DescriptionType
::
ReduceDim
s_
::
Size
()
!=
R
educeDims
.
size
())
if
(
DescriptionType
::
Num
ReduceDim
_
!=
r
educeDims
.
size
())
return
(
false
);
bool
result
=
true
;
static_for
<
0
,
DescriptionType
::
ReduceDims_
::
Size
(),
1
>
{}([
&
](
auto
i
)
{
if
(
DescriptionType
::
ReduceDims_
::
At
(
i
)
!=
ReduceDims
[
i
])
result
=
false
;
});
return
(
result
);
};
...
...
@@ -87,33 +91,29 @@ bool description_match(const DescriptionType& description,
namespace
ck
{
namespace
profiler
{
template
<
int
Rank
,
typename
ReduceDim
s
>
static
std
::
vector
<
int
>
get_reduce
_d
ims
(
)
template
<
in
dex_
t
Rank
,
index_t
Num
ReduceDim
>
static
inline
std
::
vector
<
int
>
get_
invariant_dims
(
const
std
::
vector
<
int
>&
reduce
D
ims
)
{
std
::
vector
<
int
>
resDims
;
assert
(
NumReduceDim
==
reduceDims
.
size
())
;
static_for
<
0
,
ReduceDims
::
Size
(),
1
>
{}([
&
](
auto
i
)
{
resDims
.
push_back
(
ReduceDims
::
At
(
i
));
})
;
int
reduceFlag
=
0
;
return
(
resDims
);
};
template
<
int
Rank
,
typename
ReduceDims
>
static
std
::
vector
<
int
>
get_invariant_dims
()
{
std
::
vector
<
int
>
resDims
;
unsigned
int
incFlag
=
0
;
// flag the bits for the reduceDims
for
(
int
i
=
0
;
i
<
NumReduceDim
;
i
++
)
{
reduceFlag
|=
1
<<
reduceDims
[
i
];
};
static_for
<
0
,
ReduceDims
::
Size
(),
1
>
{}(
[
&
](
auto
i
)
{
incFlag
=
incFlag
|
(
0x1
<<
ReduceDims
::
At
(
i
));
});
std
::
vector
<
int
>
invariantDims
;
for
(
int
dim
=
0
;
dim
<
Rank
;
dim
++
)
// collect invariant dimensions
for
(
int
i
=
0
;
i
<
Rank
;
i
++
)
if
((
reduceFlag
&
(
1
<<
i
))
==
0
)
{
if
(
incFlag
&
(
0x1
<<
dim
))
continue
;
resDims
.
push_back
(
dim
);
invariantDims
.
push_back
(
i
);
};
return
(
res
Dims
)
;
return
invariant
Dims
;
};
template
<
typename
T
>
...
...
@@ -133,32 +133,33 @@ static void dumpBufferToFile(const char* fileName, T* data, size_t dataNumItems)
};
// map the data type used by the GPU kernels to the corresponding type used by the host codes
template
<
typename
inData
Type
>
template
<
typename
In
Type
>
struct
type_mapping
{
using
o
ut
Data
Type
=
inData
Type
;
using
O
utType
=
In
Type
;
};
template
<
>
struct
type_mapping
<
ck
::
half_t
>
{
using
o
ut
Data
Type
=
half_float
::
half
;
using
O
utType
=
half_float
::
half
;
};
template
<
typename
InDataType
,
typename
AccDataType
,
typename
OutDataType
,
int
Rank
,
typename
ReduceDim
s_
,
ReduceTensorOp
_t
ReduceOpId
,
NanPropagation
_t
NanOpt
,
ReduceTensorIndices
_t
IndicesOpt
>
int
Num
ReduceDim
,
ReduceTensorOp
ReduceOpId
,
NanPropagation
NanOpt
,
ReduceTensorIndices
IndicesOpt
>
void
profile_reduce_impl_impl
(
bool
do_verification
,
int
init_method
,
bool
do_log
,
bool
do_dumpout
,
int
nrepeat
,
const
std
::
vector
<
size_t
>&
inLengths
,
const
std
::
vector
<
int
>&
reduceDims
,
float
alpha
,
float
beta
)
{
...
...
@@ -167,17 +168,17 @@ void profile_reduce_impl_impl(bool do_verification,
using
namespace
ck
::
host_reduce
;
constexpr
bool
op_support_indices
=
(
ReduceOpId
==
ReduceTensorOp
_t
::
MIN
||
ReduceOpId
==
ReduceTensorOp
_t
::
MAX
||
ReduceOpId
==
ReduceTensorOp
_t
::
AMAX
);
(
ReduceOpId
==
ReduceTensorOp
::
MIN
||
ReduceOpId
==
ReduceTensorOp
::
MAX
||
ReduceOpId
==
ReduceTensorOp
::
AMAX
);
constexpr
bool
NeedIndices
=
(
op_support_indices
&&
(
IndicesOpt
!=
ReduceTensorIndices
_t
::
NO_INDICES
));
(
op_support_indices
&&
(
IndicesOpt
!=
ReduceTensorIndices
::
NO_INDICES
));
constexpr
bool
PropagateNan
=
(
NanOpt
==
NanPropagation
_t
::
PROPAGATE_NAN
);
constexpr
bool
PropagateNan
=
(
NanOpt
==
NanPropagation
::
PROPAGATE_NAN
);
constexpr
bool
out_support_atomic_add
=
std
::
is_same
<
OutDataType
,
float
>::
value
;
constexpr
bool
op_support_atomic_add
=
!
op_support_indices
&&
ReduceOpId
!=
ReduceTensorOp
_t
::
NORM2
;
!
op_support_indices
&&
ReduceOpId
!=
ReduceTensorOp
::
NORM2
;
constexpr
bool
use_atomic_add
=
(
out_support_atomic_add
&&
op_support_atomic_add
);
// 1) If InDataType is half_t, must use half_t as AccDataType for indexable reduction operations
...
...
@@ -195,29 +196,47 @@ void profile_reduce_impl_impl(bool do_verification,
// 1) The indices can only be used when the reduction operation is indexable
constexpr
bool
invalid_reduce_3
=
(
!
op_support_indices
&&
IndicesOpt
!=
ReduceTensorIndices_t
::
NO_INDICES
);
(
!
op_support_indices
&&
IndicesOpt
!=
ReduceTensorIndices
::
NO_INDICES
);
// 1) If InDataType is int8_t, must use int8_t as AccDataType for indexable reduction operations
// 2) If InDataType is int8_t, must use int32_t as AccDataType for non-indexable reduction
// operations
constexpr
bool
invalid_reduce_4
=
std
::
is_same
<
InDataType
,
int8_t
>::
value
&&
((
!
op_support_indices
&&
!
std
::
is_same
<
AccDataType
,
int32_t
>::
value
)
||
(
op_support_indices
&&
!
std
::
is_same
<
AccDataType
,
int8_t
>::
value
));
constexpr
bool
invalid_reduce
=
(
invalid_reduce_1
||
invalid_reduce_2
||
invalid_reduce_3
);
// 1) If InDataType is int8_t, the supported operation must be either indexable operations or
// ADD/AVG
constexpr
bool
invalid_reduce_5
=
std
::
is_same
<
InDataType
,
int8_t
>::
value
&&
(
!
op_support_indices
&&
ReduceOpId
!=
ReduceTensorOp
::
ADD
&&
ReduceOpId
!=
ReduceTensorOp
::
AVG
);
// 1) If InDataType is bhalf_t, must use float as AccDataType for all reduction operations
constexpr
bool
invalid_reduce_6
=
std
::
is_same
<
InDataType
,
bhalf_t
>::
value
&&
!
std
::
is_same
<
AccDataType
,
float
>::
value
;
constexpr
bool
invalid_reduce
=
(
invalid_reduce_1
||
invalid_reduce_2
||
invalid_reduce_3
||
invalid_reduce_4
||
invalid_reduce_5
||
invalid_reduce_6
);
if
constexpr
(
!
invalid_reduce
)
{
Tensor
<
InDataType
>
in
(
inLengths
);
const
std
::
vector
<
int
>
OuterDims
=
get_invariant_dims
<
Rank
,
ReduceDims_
>
();
const
std
::
vector
<
int
>
ReduceDims
=
get_reduce_dims
<
Rank
,
ReduceDims_
>
();
std
::
vector
<
size_t
>
outLengths
;
if
(
OuterDims
.
empty
())
const
auto
invariantDims
=
get_invariant_dims
<
Rank
,
NumReduceDim
>
(
reduceDims
);
if
(
reduceDims
.
size
()
==
Rank
)
outLengths
.
push_back
(
1
);
else
for
(
auto
dim
:
Outer
Dims
)
for
(
auto
dim
:
invariant
Dims
)
outLengths
.
push_back
(
inLengths
[
dim
]);
Tensor
<
OutDataType
>
out_ref
(
outLengths
);
Tensor
<
OutDataType
>
out
(
outLengths
);
Tensor
<
int
>
out_indices_ref
(
outLengths
);
Tensor
<
int
>
out_indices
(
outLengths
);
Tensor
<
int
32_t
>
out_indices_ref
(
outLengths
);
Tensor
<
int
32_t
>
out_indices
(
outLengths
);
auto
inStrides
=
in
.
mDesc
.
GetStrides
();
auto
outStrides
=
out
.
mDesc
.
GetStrides
();
...
...
@@ -225,26 +244,28 @@ void profile_reduce_impl_impl(bool do_verification,
size_t
invariant_total_length
=
out
.
mDesc
.
GetElementSize
();
size_t
reduce_total_length
=
in
.
mDesc
.
GetElementSize
()
/
invariant_total_length
;
std
::
size_t
num_thread
=
std
::
thread
::
hardware_concurrency
()
;
std
::
size_t
num_thread
=
1
;
if
(
do_verification
)
{
switch
(
init_method
)
{
case
0
:
in
.
GenerateTensorValue
(
GeneratorTensor_1
<
InDataType
>
{},
num_thread
);
case
0
:
break
;
case
1
:
in
.
GenerateTensorValue
(
GeneratorTensor_1
<
InDataType
>
{
1
},
num_thread
);
if
(
beta
!=
0.0
f
)
out_ref
.
GenerateTensorValue
(
GeneratorTensor_1
<
InDataType
>
{},
num_thread
);
out_ref
.
GenerateTensorValue
(
GeneratorTensor_1
<
InDataType
>
{
1
},
num_thread
);
break
;
case
1
:
case
2
:
in
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
5
,
5
},
num_thread
);
if
(
beta
!=
0.0
f
)
out_ref
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
5
,
5
},
num_thread
);
break
;
default:
in
.
GenerateTensorValue
(
GeneratorTensor_
2
<
InDataType
>
{
1
,
5
},
num_thread
);
in
.
GenerateTensorValue
(
GeneratorTensor_
3
<
InDataType
>
{
-
5.0
,
5
.0
},
num_thread
);
if
(
beta
!=
0.0
f
)
out_ref
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
1
,
5
},
num_thread
);
out_ref
.
GenerateTensorValue
(
GeneratorTensor_3
<
InDataType
>
{
-
5.0
,
5.0
},
num_thread
);
}
if
(
beta
!=
0.0
f
)
...
...
@@ -302,7 +323,7 @@ void profile_reduce_impl_impl(bool do_verification,
AccDataType
,
OutDataType
,
Rank
,
ReduceDim
s_
,
Num
ReduceDim
,
ReduceOpId
,
NanOpt
,
IndicesOpt
>
(
reduce0_ptrs
);
...
...
@@ -311,40 +332,46 @@ void profile_reduce_impl_impl(bool do_verification,
AccDataType
,
OutDataType
,
Rank
,
ReduceDim
s_
,
Num
ReduceDim
,
ReduceOpId
,
NanOpt
,
IndicesOpt
>
(
reduce0_ptrs
);
if
constexpr
(
use_atomic_add
)
{
add_device_reduce_instance_multiblock_atomic_add
<
InDataType
,
AccDataType
,
OutDataType
,
Rank
,
ReduceDim
s_
,
Num
ReduceDim
,
ReduceOpId
,
NanOpt
,
IndicesOpt
>
(
reduce0_ptrs
);
}
else
{
add_device_reduce_instance_multiblock_partial_reduce
<
InDataType
,
AccDataType
,
OutDataType
,
Rank
,
ReduceDim
s_
,
Num
ReduceDim
,
ReduceOpId
,
NanOpt
,
IndicesOpt
>
(
reduce1_ptrs
);
};
// used for secondary reduction
if
constexpr
(
!
use_atomic_add
)
{
add_device_reduce_instance_blockwise_second_call
<
AccDataType
,
AccDataType
,
OutDataType
,
Rank
,
ReduceDim
s_
,
Num
ReduceDim
,
ReduceOpId
,
NanOpt
,
IndicesOpt
>
(
reduce2_ptrs
);
};
if
(
reduce0_ptrs
.
empty
()
&&
reduce1_ptrs
.
empty
())
{
...
...
@@ -353,17 +380,24 @@ void profile_reduce_impl_impl(bool do_verification,
if
(
do_verification
)
{
using
hIn
Type
=
typename
type_mapping
<
InDataType
>::
o
ut
Data
Type
;
using
hOut
Type
=
typename
type_mapping
<
OutDataType
>::
o
ut
Data
Type
;
using
hComp
Type
=
typename
type_mapping
<
AccDataType
>::
o
ut
Data
Type
;
using
HostInData
Type
=
typename
type_mapping
<
InDataType
>::
O
utType
;
using
HostOutData
Type
=
typename
type_mapping
<
OutDataType
>::
O
utType
;
using
HostAccData
Type
=
typename
type_mapping
<
AccDataType
>::
O
utType
;
ReductionHost
<
hInType
,
hCompType
,
hOutType
,
ReduceOpId
,
PropagateNan
,
NeedIndices
>
hostReduce
(
in
.
mDesc
,
out_ref
.
mDesc
,
OuterDims
,
ReduceDims
);
ReductionHost
<
HostInDataType
,
HostAccDataType
,
HostOutDataType
,
ReduceOpId
,
Rank
,
NumReduceDim
,
PropagateNan
,
NeedIndices
>
hostReduce
(
in
.
mDesc
,
out_ref
.
mDesc
,
invariantDims
,
reduceDims
);
hostReduce
.
Run
(
alpha
,
reinterpret_cast
<
const
hIn
Type
*>
(
in
.
mData
.
data
()),
reinterpret_cast
<
const
HostInData
Type
*>
(
in
.
mData
.
data
()),
beta
,
reinterpret_cast
<
hOut
Type
*>
(
out_ref
.
mData
.
data
()),
reinterpret_cast
<
HostOutData
Type
*>
(
out_ref
.
mData
.
data
()),
out_indices_ref
.
mData
.
data
());
};
...
...
@@ -374,23 +408,27 @@ void profile_reduce_impl_impl(bool do_verification,
for
(
auto
&
reduce_ptr
:
reduce0_ptrs
)
{
auto
wsSizeInBytes
=
reduce_ptr
->
GetWorkspaceSizeInBytes
(
i_inLengths
);
auto
wsSizeInBytes
=
reduce_ptr
->
GetWorkspaceSizeInBytes
(
i_inLengths
,
reduceDims
);
DeviceMem
ws_dev
(
wsSizeInBytes
);
auto
argument_ptr
=
reduce_ptr
->
MakeArgumentPointer
(
i_inLengths
,
InElementwiseOperation_0
in_elementwise_op_0
(
static_cast
<
int32_t
>
(
reduce_total_length
));
AccElementwiseOperation_0
acc_elementwise_op_0
(
static_cast
<
int32_t
>
(
reduce_total_length
));
auto
argument_ptr
=
reduce_ptr
->
MakeArgumentPointer
(
i_inLengths
,
i_inStrides
,
i_outLengths
,
i_outStrides
,
reduceDims
,
alpha
,
beta
,
in_dev
.
GetDeviceBuffer
(),
out_dev
.
GetDeviceBuffer
(),
out_indices_dev
.
GetDeviceBuffer
(),
ws_dev
.
GetDeviceBuffer
(),
InElementwiseOperation_0
{
static_cast
<
int32_t
>
(
reduce_total_length
)}
,
AccElementwiseOperation_0
{
static_cast
<
int32_t
>
(
reduce_total_length
)}
);
in_elementwise_op_0
,
acc_elementwise_op_0
);
if
(
!
reduce_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
continue
;
...
...
@@ -419,12 +457,13 @@ void profile_reduce_impl_impl(bool do_verification,
if
(
do_verification
)
{
out_dev
.
FromDevice
(
out
.
mData
.
data
());
c
heck_error
(
out_ref
,
out
);
c
k
::
utils
::
check_err
(
out
.
mData
,
out_ref
.
mData
);
if
(
NeedIndices
)
{
out_indices_dev
.
FromDevice
(
out_indices
.
mData
.
data
());
check_indices
(
out_indices_ref
,
out_indices
);
ck
::
utils
::
check_err
(
out_indices
.
mData
,
out_indices_ref
.
mData
);
;
};
if
(
do_log
)
...
...
@@ -455,23 +494,27 @@ void profile_reduce_impl_impl(bool do_verification,
for
(
auto
&
reduce_ptr
:
reduce1_ptrs
)
{
auto
wsSizeInBytes
=
reduce_ptr
->
GetWorkspaceSizeInBytes
(
i_inLengths
);
auto
wsSizeInBytes
=
reduce_ptr
->
GetWorkspaceSizeInBytes
(
i_inLengths
,
reduceDims
);
DeviceMem
ws_dev
(
wsSizeInBytes
);
auto
argument_ptr
=
reduce_ptr
->
MakeArgumentPointer
(
i_inLengths
,
InElementwiseOperation_1
in_elementwise_op_1
(
static_cast
<
int32_t
>
(
reduce_total_length
));
AccElementwiseOperation_1
acc_elementwise_op_1
(
static_cast
<
int32_t
>
(
reduce_total_length
));
auto
argument_ptr
=
reduce_ptr
->
MakeArgumentPointer
(
i_inLengths
,
i_inStrides
,
i_outLengths
,
i_outStrides
,
reduceDims
,
alpha
,
beta
,
in_dev
.
GetDeviceBuffer
(),
out_dev
.
GetDeviceBuffer
(),
out_indices_dev
.
GetDeviceBuffer
(),
ws_dev
.
GetDeviceBuffer
(),
InElementwiseOperation_1
{
static_cast
<
int32_t
>
(
reduce_total_length
)}
,
AccElementwiseOperation_1
{
static_cast
<
int32_t
>
(
reduce_total_length
)}
);
in_elementwise_op_1
,
acc_elementwise_op_1
);
if
(
!
reduce_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
continue
;
...
...
@@ -491,19 +534,25 @@ void profile_reduce_impl_impl(bool do_verification,
for
(
auto
&
reduce2_ptr
:
reduce2_ptrs
)
{
auto
argument2_ptr
=
reduce2_ptr
->
MakeArgumentPointer
(
inLengths2
,
InElementwiseOperation_2
in_elementwise_op_2
(
static_cast
<
int32_t
>
(
reduce_total_length
));
AccElementwiseOperation_2
acc_elementwise_op_2
(
static_cast
<
int32_t
>
(
reduce_total_length
));
auto
argument2_ptr
=
reduce2_ptr
->
MakeArgumentPointer
(
inLengths2
,
inStrides2
,
i_outLengths
,
i_outStrides
,
reduceDims
,
alpha
,
beta
,
ws_dev
.
GetDeviceBuffer
(),
out_dev
.
GetDeviceBuffer
(),
out_indices_dev
.
GetDeviceBuffer
(),
ws_dev
.
GetDeviceBuffer
(),
InElementwiseOperation_2
{
static_cast
<
int32_t
>
(
reduce_total_length
)}
,
AccElementwiseOperation_2
{
static_cast
<
int32_t
>
(
reduce_total_length
)}
);
in_elementwise_op_2
,
acc_elementwise_op_2
);
if
(
!
reduce2_ptr
->
IsSupportedArgument
(
argument2_ptr
.
get
()))
continue
;
...
...
@@ -531,12 +580,13 @@ void profile_reduce_impl_impl(bool do_verification,
if
(
do_verification
)
{
out_dev
.
FromDevice
(
out
.
mData
.
data
());
c
heck_error
(
out_ref
,
out
);
c
k
::
utils
::
check_err
(
out
.
mData
,
out_ref
.
mData
);
if
(
NeedIndices
)
{
out_indices_dev
.
FromDevice
(
out_indices
.
mData
.
data
());
check_indices
(
out_indices_ref
,
out_indices
);
ck
::
utils
::
check_err
(
out_indices
.
mData
,
out_indices_ref
.
mData
);
;
};
if
(
do_log
)
...
...
@@ -584,10 +634,10 @@ void profile_reduce_impl(bool do_verification,
bool
do_dumpout
,
int
nrepeat
,
const
std
::
vector
<
size_t
>&
inLengths
,
const
std
::
vector
<
int
>&
R
educeDims
,
ReduceTensorOp
_t
ReduceOpId
,
NanPropagation
_t
NanOpt
,
ReduceTensorIndices
_t
IndicesOpt
,
const
std
::
vector
<
int
>&
r
educeDims
,
ReduceTensorOp
ReduceOpId
,
NanPropagation
NanOpt
,
ReduceTensorIndices
IndicesOpt
,
float
alpha
,
float
beta
)
{
...
...
@@ -605,18 +655,26 @@ void profile_reduce_impl(bool do_verification,
using
descType
=
remove_cvref_t
<
decltype
(
std
::
get
<
i
>
(
tuple_object
))
>
;
if
(
!
description_match
(
descType
{},
inLengths
.
size
(),
R
educeDims
,
ReduceOpId
,
NanOpt
,
IndicesOpt
))
descType
{},
inLengths
.
size
(),
r
educeDims
,
ReduceOpId
,
NanOpt
,
IndicesOpt
))
return
;
profile_reduce_impl_impl
<
InDataType
,
AccDataType
,
OutDataType
,
descType
::
Rank_
,
typename
descType
::
ReduceDims_
,
static_cast
<
ReduceTensorOp_t
>
(
descType
::
ReduceOpId_
),
static_cast
<
NanPropagation_t
>
(
descType
::
NanOpt_
),
static_cast
<
ReduceTensorIndices_t
>
(
descType
::
IndicesOpt_
)
>
(
do_verification
,
init_method
,
do_log
,
do_dumpout
,
nrepeat
,
inLengths
,
alpha
,
beta
);
descType
::
NumReduceDim_
,
static_cast
<
ReduceTensorOp
>
(
descType
::
ReduceOpId_
),
static_cast
<
NanPropagation
>
(
descType
::
NanOpt_
),
static_cast
<
ReduceTensorIndices
>
(
descType
::
IndicesOpt_
)
>
(
do_verification
,
init_method
,
do_log
,
do_dumpout
,
nrepeat
,
inLengths
,
reduceDims
,
alpha
,
beta
);
matched
=
true
;
});
...
...
profiler/src/profile_batched_gemm.cpp
View file @
dd6a8de4
#include <cstdint>
#include <iostream>
#include <numeric>
#include <initializer_list>
...
...
@@ -15,7 +16,7 @@
#include "device_batched_gemm_xdl.hpp"
#include "profile_batched_gemm_impl.hpp"
enum
GemmMatrixLayout
enum
struct
GemmMatrixLayout
{
MK_KN_MN
,
// 0
MK_NK_MN
,
// 1
...
...
@@ -27,10 +28,12 @@ enum GemmMatrixLayout
KM_NK_NM
,
// 7
};
enum
GemmDataType
enum
struct
GemmDataType
{
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
BF16_BF16_BF16
,
// 2
INT8_INT8_INT8
,
// 3
};
int
profile_batched_gemm
(
int
argc
,
char
*
argv
[])
...
...
@@ -38,7 +41,7 @@ int profile_batched_gemm(int argc, char* argv[])
if
(
!
(
argc
==
15
))
{
printf
(
"arg1: tensor operation (batched_gemm: Batched GEMM)
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16)
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16
, 2: bf16, 3: int8
)
\n
"
);
printf
(
"arg3: matrix layout (0: A[g, m, k] * B[g, k, n] = C[g, m, n];
\n
"
);
printf
(
" 1: A[g, m, k] * B[g, n, k] = C[g, m, n];
\n
"
);
printf
(
" 2: A[g, k, m] * B[g, k, n] = C[g, m, n];
\n
"
);
...
...
@@ -51,8 +54,8 @@ int profile_batched_gemm(int argc, char* argv[])
exit
(
1
);
}
const
int
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
int
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
...
...
@@ -146,6 +149,241 @@ int profile_batched_gemm(int argc, char* argv[])
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
ck
::
profiler
::
profile_batched_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
BatchCount
);
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
ck
::
profiler
::
profile_batched_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
BatchCount
);
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
ck
::
profiler
::
profile_batched_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
ck
::
profiler
::
profile_batched_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
ck
::
profiler
::
profile_batched_gemm_impl
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
BatchCount
);
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
ck
::
profiler
::
profile_batched_gemm_impl
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
BatchCount
);
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
ck
::
profiler
::
profile_batched_gemm_impl
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
ck
::
profiler
::
profile_batched_gemm_impl
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
}
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
ck
::
profiler
::
profile_batched_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
BatchCount
);
}
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
ck
::
profiler
::
profile_batched_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
BatchCount
);
}
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
ck
::
profiler
::
profile_batched_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
}
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
ck
::
profiler
::
profile_batched_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
BatchCount
);
}
else
{
throw
std
::
runtime_error
(
"wrong! this GEMM data_type & layout is not implemented"
);
...
...
profiler/src/profile_batched_gemm_reduce.cpp
0 → 100644
View file @
dd6a8de4
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "profile_batched_gemm_reduce_impl.hpp"
int
profile_batched_gemm_reduce
(
int
argc
,
char
*
argv
[])
{
enum
struct
GemmMatrixLayout
{
MK_KN_MN
,
// 0
MK_NK_MN
,
// 1
KM_KN_MN
,
// 2
KM_NK_MN
,
// 3
};
enum
struct
GemmReduceDataType
{
F32_F32_F32_F32_F32
,
// 0
F16_F16_F16_F32_F32
,
// 1
};
if
(
!
(
argc
==
15
||
argc
==
16
))
{
printf
(
"arg1: tensor operation (batched_gemm: BatchedGEMM+Reduce)
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16)
\n
"
);
printf
(
"arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];
\n
"
);
printf
(
" 1: A[m, k] * B[n, k] = C[m, n];
\n
"
);
printf
(
" 2: A[k, m] * B[k, n] = C[m, n];
\n
"
);
printf
(
" 3: A[k, m] * B[n, k] = C[m, n])
\n
"
);
printf
(
"arg4: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg5: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg8: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg7: run kernel # of times (>1)
\n
"
);
printf
(
"arg8 to 14: M, N, K, StrideA, StrideB, StrideC, BatchCount
\n
"
);
printf
(
"arg15: split k into mulitiple batch
\n
"
);
exit
(
1
);
}
const
auto
data_type
=
static_cast
<
GemmReduceDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
7
]);
const
int
M
=
std
::
stoi
(
argv
[
8
]);
const
int
N
=
std
::
stoi
(
argv
[
9
]);
const
int
K
=
std
::
stoi
(
argv
[
10
]);
const
int
StrideA
=
std
::
stoi
(
argv
[
11
]);
const
int
StrideB
=
std
::
stoi
(
argv
[
12
]);
const
int
StrideC
=
std
::
stoi
(
argv
[
13
]);
const
int
BatchCount
=
std
::
stoi
(
argv
[
14
]);
if
(
data_type
==
GemmReduceDataType
::
F16_F16_F16_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
ck
::
profiler
::
profile_batched_gemm_reduce_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
BatchCount
);
}
else
if
(
data_type
==
GemmReduceDataType
::
F16_F16_F16_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
ck
::
profiler
::
profile_batched_gemm_reduce_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
BatchCount
);
}
else
if
(
data_type
==
GemmReduceDataType
::
F16_F16_F16_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
ck
::
profiler
::
profile_batched_gemm_reduce_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
BatchCount
);
}
else
if
(
data_type
==
GemmReduceDataType
::
F16_F16_F16_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
ck
::
profiler
::
profile_batched_gemm_reduce_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
BatchCount
);
}
else
{
throw
std
::
runtime_error
(
"wrong! this data_type & layout is not implemented"
);
}
return
1
;
}
profiler/src/profile_conv_bwd_data.cpp
View file @
dd6a8de4
...
...
@@ -6,7 +6,7 @@
#include <half.hpp>
#include "profile_conv_bwd_data_impl.hpp"
enum
ConvDataType
enum
struct
ConvDataType
{
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
...
...
@@ -14,19 +14,19 @@ enum ConvDataType
INT8_INT8_INT8
,
// 3
};
enum
ConvInputLayout
enum
struct
ConvInputLayout
{
NCHW
,
// 0
NHWC
,
// 1
};
enum
ConvWeightLayout
enum
struct
ConvWeightLayout
{
KCYX
,
// 0
KYXC
,
// 1
};
enum
ConvOutputLayout
enum
struct
ConvOutputLayout
{
NKHW
,
// 0
NHWK
,
// 1
...
...
@@ -50,10 +50,10 @@ int profile_conv_bwd_data(int argc, char* argv[])
exit
(
1
);
}
const
int
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
int
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
int
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
int
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
auto
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
auto
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
...
...
@@ -89,6 +89,7 @@ int profile_conv_bwd_data(int argc, char* argv[])
float
,
float
,
float
,
float
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
...
...
@@ -114,6 +115,7 @@ int profile_conv_bwd_data(int argc, char* argv[])
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
...
...
@@ -139,6 +141,7 @@ int profile_conv_bwd_data(int argc, char* argv[])
uint16_t
,
uint16_t
,
uint16_t
,
float
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
...
...
@@ -164,6 +167,7 @@ int profile_conv_bwd_data(int argc, char* argv[])
int8_t
,
int8_t
,
int8_t
,
int32_t
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
...
...
profiler/src/profile_conv_bwd_weight.cpp
0 → 100644
View file @
dd6a8de4
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "profile_conv_bwd_weight_impl.hpp"
enum
struct
ConvDataType
{
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
BF16_BF16_BF16
,
// 2
INT8_INT8_INT8
,
// 3
};
enum
struct
ConvInputLayout
{
NCHW
,
// 0
NHWC
,
// 1
};
enum
struct
ConvWeightLayout
{
KCYX
,
// 0
KYXC
,
// 1
};
enum
struct
ConvOutputLayout
{
NKHW
,
// 0
NHWK
,
// 1
};
int
profile_conv_bwd_weight
(
int
argc
,
char
*
argv
[])
{
if
(
argc
!=
26
)
{
printf
(
"arg1: tensor operation (conv_fwd: ForwardConvolution)
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16)
\n
"
);
printf
(
"arg3: input tensor layout (0: NCHW; 1: NHWC)
\n
"
);
printf
(
"arg4: weight tensor layout (0: KCYX; 1: KYXC)
\n
"
);
printf
(
"arg5: output tensor layout (0: NKHW; 1: NHWK)
\n
"
);
printf
(
"arg6: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg7: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg8: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg9: run kernel # of times (>1)
\n
"
);
printf
(
"arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx
\n
"
);
printf
(
"arg25: split k (>=1)
\n
"
);
exit
(
1
);
}
const
auto
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
auto
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
9
]);
const
ck
::
index_t
N
=
std
::
stoi
(
argv
[
10
]);
const
ck
::
index_t
K
=
std
::
stoi
(
argv
[
11
]);
const
ck
::
index_t
C
=
std
::
stoi
(
argv
[
12
]);
const
ck
::
index_t
Y
=
std
::
stoi
(
argv
[
13
]);
const
ck
::
index_t
X
=
std
::
stoi
(
argv
[
14
]);
const
ck
::
index_t
Hi
=
std
::
stoi
(
argv
[
15
]);
const
ck
::
index_t
Wi
=
std
::
stoi
(
argv
[
16
]);
const
ck
::
index_t
conv_stride_h
=
std
::
stoi
(
argv
[
17
]);
const
ck
::
index_t
conv_stride_w
=
std
::
stoi
(
argv
[
18
]);
const
ck
::
index_t
conv_dilation_h
=
std
::
stoi
(
argv
[
19
]);
const
ck
::
index_t
conv_dilation_w
=
std
::
stoi
(
argv
[
20
]);
const
ck
::
index_t
in_left_pad_h
=
std
::
stoi
(
argv
[
21
]);
const
ck
::
index_t
in_left_pad_w
=
std
::
stoi
(
argv
[
22
]);
const
ck
::
index_t
in_right_pad_h
=
std
::
stoi
(
argv
[
23
]);
const
ck
::
index_t
in_right_pad_w
=
std
::
stoi
(
argv
[
24
]);
ck
::
index_t
split_k
=
std
::
stoi
(
argv
[
25
]);
split_k
=
std
::
max
(
1
,
split_k
);
const
ck
::
index_t
YEff
=
(
Y
-
1
)
*
conv_dilation_h
+
1
;
const
ck
::
index_t
XEff
=
(
X
-
1
)
*
conv_dilation_w
+
1
;
const
ck
::
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
YEff
)
/
conv_stride_h
+
1
;
const
ck
::
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
XEff
)
/
conv_stride_w
+
1
;
if
(
data_type
==
ConvDataType
::
F32_F32_F32
&&
in_layout
==
ConvInputLayout
::
NHWC
&&
wei_layout
==
ConvWeightLayout
::
KYXC
&&
out_layout
==
ConvOutputLayout
::
NHWK
)
{
ck
::
profiler
::
profile_conv_bwd_weight_impl
<
2
,
float
,
float
,
float
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
N
,
K
,
C
,
std
::
vector
<
ck
::
index_t
>
{
Hi
,
Wi
},
std
::
vector
<
ck
::
index_t
>
{
Y
,
X
},
std
::
vector
<
ck
::
index_t
>
{
Ho
,
Wo
},
std
::
vector
<
ck
::
index_t
>
{
conv_stride_h
,
conv_stride_w
},
std
::
vector
<
ck
::
index_t
>
{
conv_dilation_h
,
conv_dilation_w
},
std
::
vector
<
ck
::
index_t
>
{
in_left_pad_h
,
in_left_pad_w
},
std
::
vector
<
ck
::
index_t
>
{
in_right_pad_h
,
in_right_pad_w
},
split_k
);
}
else
if
(
data_type
==
ConvDataType
::
F16_F16_F16
&&
in_layout
==
ConvInputLayout
::
NHWC
&&
wei_layout
==
ConvWeightLayout
::
KYXC
&&
out_layout
==
ConvOutputLayout
::
NHWK
)
{
ck
::
profiler
::
profile_conv_bwd_weight_impl
<
2
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
N
,
K
,
C
,
std
::
vector
<
ck
::
index_t
>
{
Hi
,
Wi
},
std
::
vector
<
ck
::
index_t
>
{
Y
,
X
},
std
::
vector
<
ck
::
index_t
>
{
Ho
,
Wo
},
std
::
vector
<
ck
::
index_t
>
{
conv_stride_h
,
conv_stride_w
},
std
::
vector
<
ck
::
index_t
>
{
conv_dilation_h
,
conv_dilation_w
},
std
::
vector
<
ck
::
index_t
>
{
in_left_pad_h
,
in_left_pad_w
},
std
::
vector
<
ck
::
index_t
>
{
in_right_pad_h
,
in_right_pad_w
},
split_k
);
}
else
{
throw
std
::
runtime_error
(
"wrong! this Conv data_type & layout is not implemented"
);
}
return
1
;
}
profiler/src/profile_conv_fwd.cpp
View file @
dd6a8de4
...
...
@@ -6,7 +6,7 @@
#include <half.hpp>
#include "profile_conv_fwd_impl.hpp"
enum
ConvDataType
enum
struct
ConvDataType
{
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
...
...
@@ -14,19 +14,19 @@ enum ConvDataType
INT8_INT8_INT8
,
// 3
};
enum
ConvInputLayout
enum
struct
ConvInputLayout
{
NCHW
,
// 0
NHWC
,
// 1
};
enum
ConvWeightLayout
enum
struct
ConvWeightLayout
{
KCYX
,
// 0
KYXC
,
// 1
};
enum
ConvOutputLayout
enum
struct
ConvOutputLayout
{
NKHW
,
// 0
NHWK
,
// 1
...
...
@@ -50,10 +50,10 @@ int profile_conv_fwd(int argc, char* argv[])
exit
(
1
);
}
const
int
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
int
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
int
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
int
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
auto
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
auto
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
...
...
profiler/src/profile_conv_fwd_bias_relu.cpp
View file @
dd6a8de4
...
...
@@ -6,25 +6,25 @@
#include <half.hpp>
#include "profile_conv_fwd_bias_relu_impl.hpp"
enum
ConvDataType
enum
struct
ConvDataType
{
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
};
enum
ConvInputLayout
enum
struct
ConvInputLayout
{
NCHW
,
// 0
NHWC
,
// 1
};
enum
ConvWeightLayout
enum
struct
ConvWeightLayout
{
KCYX
,
// 0
KYXC
,
// 1
};
enum
ConvOutputLayout
enum
struct
ConvOutputLayout
{
NKHW
,
// 0
NHWK
,
// 1
...
...
@@ -48,10 +48,10 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[])
exit
(
1
);
}
const
int
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
int
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
int
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
int
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
auto
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
auto
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
...
...
profiler/src/profile_conv_fwd_bias_relu_add.cpp
View file @
dd6a8de4
...
...
@@ -6,25 +6,25 @@
#include <half.hpp>
#include "profile_conv_fwd_bias_relu_add_impl.hpp"
enum
ConvDataType
enum
struct
ConvDataType
{
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
};
enum
ConvInputLayout
enum
struct
ConvInputLayout
{
NCHW
,
// 0
NHWC
,
// 1
};
enum
ConvWeightLayout
enum
struct
ConvWeightLayout
{
KCYX
,
// 0
KYXC
,
// 1
};
enum
ConvOutputLayout
enum
struct
ConvOutputLayout
{
NKHW
,
// 0
NHWK
,
// 1
...
...
@@ -49,10 +49,10 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[])
exit
(
1
);
}
const
int
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
int
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
int
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
int
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
auto
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
auto
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
...
...
profiler/src/profile_conv_fwd_bias_relu_atomic_add.cpp
View file @
dd6a8de4
...
...
@@ -6,25 +6,25 @@
#include <half.hpp>
#include "profile_conv_fwd_bias_relu_atomic_add_impl.hpp"
enum
ConvDataType
enum
struct
ConvDataType
{
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
};
enum
ConvInputLayout
enum
struct
ConvInputLayout
{
NCHW
,
// 0
NHWC
,
// 1
};
enum
ConvWeightLayout
enum
struct
ConvWeightLayout
{
KCYX
,
// 0
KYXC
,
// 1
};
enum
ConvOutputLayout
enum
struct
ConvOutputLayout
{
NKHW
,
// 0
NHWK
,
// 1
...
...
@@ -49,10 +49,10 @@ int profile_conv_fwd_bias_relu_atomic_add(int argc, char* argv[])
exit
(
1
);
}
const
int
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
int
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
int
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
int
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
auto
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
auto
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
...
...
profiler/src/profile_convnd_bwd_data.cpp
0 → 100644
View file @
dd6a8de4
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "profile_convnd_bwd_data_impl.hpp"
enum
struct
ConvDataType
{
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
BF16_BF16_BF16
,
// 2
INT8_INT8_INT8
,
// 3
};
enum
struct
ConvInputLayout
{
NCHW
,
// 0
NHWC
,
// 1
};
enum
struct
ConvWeightLayout
{
KCYX
,
// 0
KYXC
,
// 1
};
enum
struct
ConvOutputLayout
{
NKHW
,
// 0
NHWK
,
// 1
};
ck
::
utils
::
conv
::
ConvParams
parse_conv_params
(
int
num_dim_spatial
,
char
*
argv
[],
int
arg_idx
)
{
// (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right)
ck
::
utils
::
conv
::
ConvParams
params
;
params
.
num_dim_spatial
=
num_dim_spatial
;
params
.
N
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
K
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
C
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
filter_spatial_lengths
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
params
.
filter_spatial_lengths
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
params
.
input_spatial_lengths
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
params
.
input_spatial_lengths
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
params
.
conv_filter_strides
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
params
.
conv_filter_strides
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
params
.
conv_filter_dilations
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
params
.
conv_filter_dilations
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
params
.
input_left_pads
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
params
.
input_left_pads
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
params
.
input_right_pads
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
params
.
input_right_pads
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
return
params
;
}
int
profile_convnd_bwd_data
(
int
argc
,
char
*
argv
[],
int
num_dim_spatial
)
{
const
int
preParams
=
10
;
int
conv_args
=
3
+
num_dim_spatial
*
6
;
int
cmdline_nargs
=
conv_args
+
preParams
;
if
(
cmdline_nargs
!=
argc
)
{
printf
(
"arg1: tensor operation (conv[1|2|3]d_bwd_data: BackwardConvolution)
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16)
\n
"
);
printf
(
"arg3: input tensor layout (0: NCHW; 1: NHWC)
\n
"
);
printf
(
"arg4: weight tensor layout (0: KCYX; 1: KYXC)
\n
"
);
printf
(
"arg5: output tensor layout (0: NKHW; 1: NHWK)
\n
"
);
printf
(
"arg6: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg7: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg8: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg9: run kernel # of times (>1)
\n
"
);
printf
(
"arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx
\n
"
);
return
1
;
}
const
auto
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
auto
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
9
]);
ck
::
utils
::
conv
::
ConvParams
params
=
parse_conv_params
(
num_dim_spatial
,
argv
,
preParams
);
auto
Run
=
[
&
](
auto
input_type
,
auto
wei_type
,
auto
out_type
,
auto
acc_type
)
{
using
InDataType
=
decltype
(
input_type
);
using
WeiDataType
=
decltype
(
wei_type
);
using
OutDataType
=
decltype
(
out_type
);
using
AccDataType
=
decltype
(
acc_type
);
switch
(
num_dim_spatial
)
{
case
1
:
ck
::
profiler
::
profile_convnd_bwd_data_impl
<
1
,
InDataType
,
WeiDataType
,
OutDataType
,
AccDataType
,
ck
::
tensor_layout
::
convolution
::
NWC
,
ck
::
tensor_layout
::
convolution
::
KXC
,
ck
::
tensor_layout
::
convolution
::
NWK
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
params
.
N
,
params
.
K
,
params
.
C
,
params
.
input_spatial_lengths
,
params
.
filter_spatial_lengths
,
params
.
GetOutputSpatialLengths
(),
params
.
conv_filter_strides
,
params
.
conv_filter_dilations
,
params
.
input_left_pads
,
params
.
input_right_pads
);
break
;
case
2
:
ck
::
profiler
::
profile_convnd_bwd_data_impl
<
2
,
InDataType
,
WeiDataType
,
OutDataType
,
AccDataType
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
params
.
N
,
params
.
K
,
params
.
C
,
params
.
input_spatial_lengths
,
params
.
filter_spatial_lengths
,
params
.
GetOutputSpatialLengths
(),
params
.
conv_filter_strides
,
params
.
conv_filter_dilations
,
params
.
input_left_pads
,
params
.
input_right_pads
);
break
;
case
3
:
ck
::
profiler
::
profile_convnd_bwd_data_impl
<
3
,
InDataType
,
WeiDataType
,
OutDataType
,
AccDataType
,
ck
::
tensor_layout
::
convolution
::
NDHWC
,
ck
::
tensor_layout
::
convolution
::
KZYXC
,
ck
::
tensor_layout
::
convolution
::
NDHWK
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
params
.
N
,
params
.
K
,
params
.
C
,
params
.
input_spatial_lengths
,
params
.
filter_spatial_lengths
,
params
.
GetOutputSpatialLengths
(),
params
.
conv_filter_strides
,
params
.
conv_filter_dilations
,
params
.
input_left_pads
,
params
.
input_right_pads
);
break
;
default:
break
;
}
};
if
(
data_type
==
ConvDataType
::
F32_F32_F32
&&
in_layout
==
ConvInputLayout
::
NHWC
&&
wei_layout
==
ConvWeightLayout
::
KYXC
&&
out_layout
==
ConvOutputLayout
::
NHWK
)
{
Run
(
float
{},
float
{},
float
{},
float
{});
}
else
if
(
data_type
==
ConvDataType
::
F16_F16_F16
&&
in_layout
==
ConvInputLayout
::
NHWC
&&
wei_layout
==
ConvWeightLayout
::
KYXC
&&
out_layout
==
ConvOutputLayout
::
NHWK
)
{
Run
(
ck
::
half_t
{},
ck
::
half_t
{},
ck
::
half_t
{},
float
{});
}
else
if
(
data_type
==
ConvDataType
::
BF16_BF16_BF16
&&
in_layout
==
ConvInputLayout
::
NHWC
&&
wei_layout
==
ConvWeightLayout
::
KYXC
&&
out_layout
==
ConvOutputLayout
::
NHWK
)
{
Run
(
ck
::
bhalf_t
{},
ck
::
bhalf_t
{},
ck
::
bhalf_t
{},
float
{});
}
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
&&
in_layout
==
ConvInputLayout
::
NHWC
&&
wei_layout
==
ConvWeightLayout
::
KYXC
&&
out_layout
==
ConvOutputLayout
::
NHWK
)
{
Run
(
int8_t
{},
int8_t
{},
int8_t
{},
int32_t
{});
}
else
{
std
::
cout
<<
"wrong! this Conv data_type & layout is not implemented"
<<
std
::
endl
;
return
1
;
}
return
0
;
}
profiler/src/profile_gemm.cpp
View file @
dd6a8de4
...
...
@@ -6,7 +6,7 @@
#include <half.hpp>
#include "profile_gemm_impl.hpp"
enum
GemmMatrixLayout
enum
struct
GemmMatrixLayout
{
MK_KN_MN
,
// 0
MK_NK_MN
,
// 1
...
...
@@ -18,7 +18,7 @@ enum GemmMatrixLayout
KM_NK_NM
,
// 7
};
enum
GemmDataType
enum
struct
GemmDataType
{
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
...
...
@@ -45,8 +45,8 @@ int profile_gemm(int argc, char* argv[])
exit
(
1
);
}
const
int
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
int
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
...
...
@@ -223,6 +223,26 @@ int profile_gemm(int argc, char* argv[])
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
}
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
}
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
...
...
@@ -243,6 +263,66 @@ int profile_gemm(int argc, char* argv[])
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
}
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
}
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
...
...
@@ -263,6 +343,46 @@ int profile_gemm(int argc, char* argv[])
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
}
else
{
throw
std
::
runtime_error
(
"wrong! this GEMM data_type & layout is not implemented"
);
...
...
profiler/src/profile_gemm_bias_2d.cpp
View file @
dd6a8de4
...
...
@@ -6,7 +6,7 @@
#include <half.hpp>
#include "profile_gemm_bias_2d_impl.hpp"
enum
GemmMatrixLayout
enum
struct
GemmMatrixLayout
{
MK_KN_MN
,
// 0
MK_NK_MN
,
// 1
...
...
@@ -18,7 +18,7 @@ enum GemmMatrixLayout
KM_NK_NM
,
// 7
};
enum
GemmDataType
enum
struct
GemmDataType
{
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
...
...
@@ -45,8 +45,8 @@ int profile_gemm_bias_2d(int argc, char* argv[])
exit
(
1
);
}
const
int
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
int
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
...
...
profiler/src/profile_gemm_bias_relu.cpp
View file @
dd6a8de4
...
...
@@ -6,7 +6,7 @@
#include <half.hpp>
#include "profile_gemm_bias_relu_impl.hpp"
enum
GemmMatrixLayout
enum
struct
GemmMatrixLayout
{
MK_KN_MN
,
// 0
MK_NK_MN
,
// 1
...
...
@@ -18,7 +18,7 @@ enum GemmMatrixLayout
KM_NK_NM
,
// 7
};
enum
GemmDataType
enum
struct
GemmDataType
{
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
...
...
@@ -43,8 +43,8 @@ int profile_gemm_bias_relu(int argc, char* argv[])
exit
(
1
);
}
const
int
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
int
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
...
...
profiler/src/profile_gemm_bias_relu_add.cpp
View file @
dd6a8de4
...
...
@@ -6,7 +6,7 @@
#include <half.hpp>
#include "profile_gemm_bias_relu_add_impl.hpp"
enum
GemmMatrixLayout
enum
struct
GemmMatrixLayout
{
MK_KN_MN
,
// 0
MK_NK_MN
,
// 1
...
...
@@ -18,7 +18,7 @@ enum GemmMatrixLayout
KM_NK_NM
,
// 7
};
enum
GemmDataType
enum
struct
GemmDataType
{
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
...
...
@@ -43,8 +43,8 @@ int profile_gemm_bias_relu_add(int argc, char* argv[])
exit
(
1
);
}
const
int
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
int
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
...
...
Prev
1
…
17
18
19
20
21
22
23
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment