Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e72c0c43
Commit
e72c0c43
authored
Mar 26, 2022
by
carlushuang
Browse files
Merge remote-tracking branch 'origin/develop' into cpu_avx2
parents
d714fa15
313bbea5
Changes
262
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1513 additions
and
150 deletions
+1513
-150
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.cpp
...pu/reduce/device_reduce_instance_threadwise_i8_i32_i8.cpp
+25
-0
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.cpp
...gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.cpp
+40
-0
profiler/CMakeLists.txt
profiler/CMakeLists.txt
+5
-0
profiler/include/profile_batched_gemm_impl.hpp
profiler/include/profile_batched_gemm_impl.hpp
+179
-13
profiler/include/profile_conv_bwd_data_impl.hpp
profiler/include/profile_conv_bwd_data_impl.hpp
+4
-4
profiler/include/profile_gemm_bias_2d_impl.hpp
profiler/include/profile_gemm_bias_2d_impl.hpp
+1
-1
profiler/include/profile_gemm_impl.hpp
profiler/include/profile_gemm_impl.hpp
+73
-7
profiler/include/profile_gemm_reduce_impl.hpp
profiler/include/profile_gemm_reduce_impl.hpp
+335
-0
profiler/include/profile_grouped_gemm_impl.hpp
profiler/include/profile_grouped_gemm_impl.hpp
+314
-0
profiler/include/profile_reduce_impl.hpp
profiler/include/profile_reduce_impl.hpp
+118
-64
profiler/src/README.md
profiler/src/README.md
+2
-2
profiler/src/profile_batched_gemm.cpp
profiler/src/profile_batched_gemm.cpp
+245
-7
profiler/src/profile_conv_bwd_data.cpp
profiler/src/profile_conv_bwd_data.cpp
+8
-8
profiler/src/profile_conv_fwd.cpp
profiler/src/profile_conv_fwd.cpp
+8
-8
profiler/src/profile_conv_fwd_bias_relu.cpp
profiler/src/profile_conv_fwd_bias_relu.cpp
+8
-8
profiler/src/profile_conv_fwd_bias_relu_add.cpp
profiler/src/profile_conv_fwd_bias_relu_add.cpp
+8
-8
profiler/src/profile_conv_fwd_bias_relu_atomic_add.cpp
profiler/src/profile_conv_fwd_bias_relu_atomic_add.cpp
+8
-8
profiler/src/profile_gemm.cpp
profiler/src/profile_gemm.cpp
+124
-4
profiler/src/profile_gemm_bias_2d.cpp
profiler/src/profile_gemm_bias_2d.cpp
+4
-4
profiler/src/profile_gemm_bias_relu.cpp
profiler/src/profile_gemm_bias_relu.cpp
+4
-4
No files found.
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.cpp
0 → 100644
View file @
e72c0c43
#include "device_reduce_instance_threadwise.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_reduce_instance
{
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int32_t
,
int8_t
,
0
,
0
,
0
,
4
,
3
);
// for ADD
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int32_t
,
int8_t
,
0
,
0
,
0
,
4
,
4
);
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int32_t
,
int8_t
,
0
,
0
,
0
,
4
,
1
);
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int32_t
,
int8_t
,
0
,
0
,
0
,
2
,
1
);
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int32_t
,
int8_t
,
5
,
0
,
0
,
4
,
3
);
// for AVG
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int32_t
,
int8_t
,
5
,
0
,
0
,
4
,
4
);
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int32_t
,
int8_t
,
5
,
0
,
0
,
4
,
1
);
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int32_t
,
int8_t
,
5
,
0
,
0
,
2
,
1
);
// clang-format on
// clang-format on
}
// namespace device_reduce_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.cpp
0 → 100644
View file @
e72c0c43
#include "device_reduce_instance_threadwise.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_reduce_instance
{
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
2
,
0
,
0
,
4
,
3
);
// for MIN
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
2
,
0
,
0
,
4
,
4
);
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
2
,
0
,
0
,
4
,
1
);
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
2
,
0
,
0
,
2
,
1
);
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
3
,
0
,
0
,
4
,
3
);
// for MAX
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
3
,
0
,
0
,
4
,
4
);
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
3
,
0
,
0
,
4
,
1
);
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
3
,
0
,
0
,
2
,
1
);
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
4
,
0
,
0
,
4
,
3
);
// for AMAX
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
4
,
0
,
0
,
4
,
4
);
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
4
,
0
,
0
,
4
,
1
);
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
4
,
0
,
0
,
2
,
1
);
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
2
,
0
,
1
,
4
,
3
);
// for MIN
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
2
,
0
,
1
,
4
,
4
);
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
2
,
0
,
1
,
4
,
1
);
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
2
,
0
,
1
,
2
,
1
);
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
3
,
0
,
1
,
4
,
3
);
// for MAX
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
3
,
0
,
1
,
4
,
4
);
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
3
,
0
,
1
,
4
,
1
);
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
3
,
0
,
1
,
2
,
1
);
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
4
,
0
,
1
,
4
,
3
);
// for AMAX
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
4
,
0
,
1
,
4
,
4
);
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
4
,
0
,
1
,
4
,
1
);
ADD_THREADWISE_INST_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
4
,
0
,
1
,
2
,
1
);
// clang-format on
}
// namespace device_reduce_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
profiler/CMakeLists.txt
View file @
e72c0c43
...
@@ -26,6 +26,7 @@ set(PROFILER_SOURCE
...
@@ -26,6 +26,7 @@ set(PROFILER_SOURCE
src/profile_gemm_bias_2d.cpp
src/profile_gemm_bias_2d.cpp
src/profile_gemm_bias_relu.cpp
src/profile_gemm_bias_relu.cpp
src/profile_gemm_bias_relu_add.cpp
src/profile_gemm_bias_relu_add.cpp
src/profile_gemm_reduce.cpp
src/profile_batched_gemm.cpp
src/profile_batched_gemm.cpp
src/profile_conv_fwd.cpp
src/profile_conv_fwd.cpp
src/profile_conv_fwd_bias_relu.cpp
src/profile_conv_fwd_bias_relu.cpp
...
@@ -33,11 +34,13 @@ set(PROFILER_SOURCE
...
@@ -33,11 +34,13 @@ set(PROFILER_SOURCE
src/profile_conv_fwd_bias_relu_atomic_add.cpp
src/profile_conv_fwd_bias_relu_atomic_add.cpp
src/profile_conv_bwd_data.cpp
src/profile_conv_bwd_data.cpp
src/profile_reduce.cpp
src/profile_reduce.cpp
src/profile_grouped_gemm.cpp
)
)
add_executable
(
ckProfiler
${
PROFILER_SOURCE
}
)
add_executable
(
ckProfiler
${
PROFILER_SOURCE
}
)
target_link_libraries
(
ckProfiler PRIVATE host_tensor
)
target_link_libraries
(
ckProfiler PRIVATE host_tensor
)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_reduce_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_bias2d_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_bias2d_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_bias_relu_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_bias_relu_instance
)
...
@@ -49,3 +52,5 @@ target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instanc
...
@@ -49,3 +52,5 @@ target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instanc
target_link_libraries
(
ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_conv2d_bwd_data_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_conv2d_bwd_data_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_reduce_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_reduce_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_reduce_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_grouped_gemm_instance
)
profiler/include/profile_batched_gemm_impl.hpp
View file @
e72c0c43
#pragma once
#pragma once
#include <memory>
#include "reference_batched_gemm.hpp"
#include "reference_batched_gemm.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -11,10 +12,30 @@ using DeviceGemmNoOpPtr =
...
@@ -11,10 +12,30 @@ using DeviceGemmNoOpPtr =
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
void
add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
}
// namespace device_batched_gemm_instance
}
// namespace device_batched_gemm_instance
}
// namespace device
}
// namespace device
...
@@ -65,6 +86,8 @@ void profile_batched_gemm_impl(int do_verification,
...
@@ -65,6 +86,8 @@ void profile_batched_gemm_impl(int do_verification,
f_host_tensor_descriptor
(
BatchCount
,
M
,
N
,
StrideC
,
CLayout
{}));
f_host_tensor_descriptor
(
BatchCount
,
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_g_m_n_device_result
(
Tensor
<
CDataType
>
c_g_m_n_device_result
(
f_host_tensor_descriptor
(
BatchCount
,
M
,
N
,
StrideC
,
CLayout
{}));
f_host_tensor_descriptor
(
BatchCount
,
M
,
N
,
StrideC
,
CLayout
{}));
std
::
unique_ptr
<
Tensor
<
float
>>
c_f32_g_m_n_host_result
=
nullptr
;
std
::
unique_ptr
<
Tensor
<
float
>>
c_f32_g_m_n_device_result
=
nullptr
;
std
::
cout
<<
"a_g_m_k: "
<<
a_g_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"a_g_m_k: "
<<
a_g_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_g_k_n: "
<<
b_g_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_g_k_n: "
<<
b_g_k_n
.
mDesc
<<
std
::
endl
;
...
@@ -95,21 +118,56 @@ void profile_batched_gemm_impl(int do_verification,
...
@@ -95,21 +118,56 @@ void profile_batched_gemm_impl(int do_verification,
if
(
do_verification
)
if
(
do_verification
)
{
{
using
ReferenceBatchedGemmInstance
=
if
constexpr
(
is_same
<
ADataType
,
ck
::
bhalf_t
>::
value
&&
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
is_same
<
BDataType
,
ck
::
bhalf_t
>::
value
&&
BDataType
,
is_same
<
CDataType
,
ck
::
bhalf_t
>::
value
)
CDataType
,
{
AElementOp
,
Tensor
<
float
>
a_f32_g_m_k
(
BElementOp
,
f_host_tensor_descriptor
(
BatchCount
,
M
,
K
,
StrideA
,
ALayout
{}));
CElementOp
>
;
Tensor
<
float
>
b_f32_g_k_n
(
f_host_tensor_descriptor
(
BatchCount
,
K
,
N
,
StrideB
,
BLayout
{}));
c_f32_g_m_n_host_result
=
std
::
make_unique
<
Tensor
<
float
>>
(
f_host_tensor_descriptor
(
BatchCount
,
M
,
N
,
StrideC
,
CLayout
{}));
c_f32_g_m_n_device_result
=
std
::
make_unique
<
Tensor
<
float
>>
(
f_host_tensor_descriptor
(
BatchCount
,
M
,
N
,
StrideC
,
CLayout
{}));
bf16_to_f32_
(
a_g_m_k
,
a_f32_g_m_k
);
bf16_to_f32_
(
b_g_k_n
,
b_f32_g_k_n
);
using
ReferenceBatchedGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
float
,
float
,
float
,
AElementOp
,
BElementOp
,
CElementOp
>
;
auto
ref_batched_gemm
=
ReferenceBatchedGemmInstance
{};
auto
ref_invoker
=
ref_batched_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_batched_gemm
.
MakeArgument
(
a_f32_g_m_k
,
b_f32_g_k_n
,
*
c_f32_g_m_n_host_result
,
a_element_op
,
b_element_op
,
c_element_op
);
ref_invoker
.
Run
(
ref_argument
);
}
else
{
using
ReferenceBatchedGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
auto
ref_batched_gemm
=
ReferenceBatchedGemmInstance
{};
auto
ref_batched_gemm
=
ReferenceBatchedGemmInstance
{};
auto
ref_invoker
=
ref_batched_gemm
.
MakeInvoker
();
auto
ref_invoker
=
ref_batched_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_batched_gemm
.
MakeArgument
(
auto
ref_argument
=
ref_batched_gemm
.
MakeArgument
(
a_g_m_k
,
b_g_k_n
,
c_g_m_n_host_result
,
a_element_op
,
b_element_op
,
c_element_op
);
a_g_m_k
,
b_g_k_n
,
c_g_m_n_host_result
,
a_element_op
,
b_element_op
,
c_element_op
);
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
}
}
}
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_g_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_g_m_k
.
mDesc
.
GetElementSpace
());
...
@@ -156,6 +214,102 @@ void profile_batched_gemm_impl(int do_verification,
...
@@ -156,6 +214,102 @@ void profile_batched_gemm_impl(int do_verification,
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances
(
gemm_ptrs
);
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances
(
gemm_ptrs
);
}
}
}
}
else
if
constexpr
(
is_same
<
ADataType
,
bhalf_t
>::
value
&&
is_same
<
BDataType
,
bhalf_t
>::
value
&&
is_same
<
CDataType
,
bhalf_t
>::
value
)
{
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances
(
gemm_ptrs
);
}
}
else
if
constexpr
(
is_same
<
ADataType
,
float
>::
value
&&
is_same
<
BDataType
,
float
>::
value
&&
is_same
<
CDataType
,
float
>::
value
)
{
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances
(
gemm_ptrs
);
}
}
else
if
constexpr
(
is_same
<
ADataType
,
int8_t
>::
value
&&
is_same
<
BDataType
,
int8_t
>::
value
&&
is_same
<
CDataType
,
int8_t
>::
value
)
{
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances
(
gemm_ptrs
);
}
}
if
(
gemm_ptrs
.
size
()
<=
0
)
if
(
gemm_ptrs
.
size
()
<=
0
)
{
{
...
@@ -218,7 +372,19 @@ void profile_batched_gemm_impl(int do_verification,
...
@@ -218,7 +372,19 @@ void profile_batched_gemm_impl(int do_verification,
{
{
c_device_buf
.
FromDevice
(
c_g_m_n_device_result
.
mData
.
data
());
c_device_buf
.
FromDevice
(
c_g_m_n_device_result
.
mData
.
data
());
check_error
(
c_g_m_n_host_result
,
c_g_m_n_device_result
);
if
constexpr
(
is_same
<
ADataType
,
ck
::
bhalf_t
>::
value
&&
is_same
<
BDataType
,
ck
::
bhalf_t
>::
value
&&
is_same
<
CDataType
,
ck
::
bhalf_t
>::
value
)
{
bf16_to_f32_
(
c_g_m_n_device_result
,
*
c_f32_g_m_n_device_result
);
check_error
(
*
c_f32_g_m_n_host_result
,
*
c_f32_g_m_n_device_result
);
}
else
{
check_error
(
c_g_m_n_host_result
,
c_g_m_n_device_result
);
}
if
(
do_log
)
if
(
do_log
)
{
{
...
...
profiler/include/profile_conv_bwd_data_impl.hpp
View file @
e72c0c43
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
F32
=
float
;
using
BF16
=
ushor
t
;
using
BF16
=
ck
::
bhalf_
t
;
using
INT8
=
int8_t
;
using
INT8
=
int8_t
;
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -172,9 +172,9 @@ void profile_conv_bwd_data_impl(int do_verification,
...
@@ -172,9 +172,9 @@ void profile_conv_bwd_data_impl(int do_verification,
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_data_instance
::
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_data_instance
::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances
(
conv_ptrs
);
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances
(
conv_ptrs
);
}
}
else
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cv_t
<
InDataType
>
,
ushor
t
>
&&
else
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cv_t
<
InDataType
>
,
ck
::
bhalf_
t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
ushor
t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
ck
::
bhalf_
t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
ushor
t
>
)
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
ck
::
bhalf_
t
>
)
{
{
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_data_instance
::
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_data_instance
::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances
(
conv_ptrs
);
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances
(
conv_ptrs
);
...
...
profiler/include/profile_gemm_bias_2d_impl.hpp
View file @
e72c0c43
...
@@ -7,7 +7,7 @@
...
@@ -7,7 +7,7 @@
#include "tensor_layout.hpp"
#include "tensor_layout.hpp"
#include "device_tensor.hpp"
#include "device_tensor.hpp"
#include "element_wise_operation.hpp"
#include "element_wise_operation.hpp"
#include "device_gemm.hpp"
#include "device_gemm
_bias
.hpp"
#include "reference_gemm_bias_2d.hpp"
#include "reference_gemm_bias_2d.hpp"
namespace
ck
{
namespace
ck
{
...
...
profiler/include/profile_gemm_impl.hpp
View file @
e72c0c43
...
@@ -26,16 +26,28 @@ void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNo
...
@@ -26,16 +26,28 @@ void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNo
void
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances
(
void
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances
(
void
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances
(
void
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
...
@@ -45,6 +57,11 @@ void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNo
...
@@ -45,6 +57,11 @@ void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNo
void
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
...
@@ -127,11 +144,6 @@ void profile_gemm_impl(int do_verification,
...
@@ -127,11 +144,6 @@ void profile_gemm_impl(int do_verification,
const
auto
b_element_op
=
BElementOp
{};
const
auto
b_element_op
=
BElementOp
{};
const
auto
c_element_op
=
CElementOp
{};
const
auto
c_element_op
=
CElementOp
{};
// if(do_verification)
// {
// }
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpace
());
...
@@ -159,6 +171,9 @@ void profile_gemm_impl(int do_verification,
...
@@ -159,6 +171,9 @@ void profile_gemm_impl(int do_verification,
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances
(
gemm_ptrs
);
}
}
}
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
...
@@ -174,6 +189,9 @@ void profile_gemm_impl(int do_verification,
...
@@ -174,6 +189,9 @@ void profile_gemm_impl(int do_verification,
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances
(
gemm_ptrs
);
}
}
}
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
...
@@ -189,6 +207,9 @@ void profile_gemm_impl(int do_verification,
...
@@ -189,6 +207,9 @@ void profile_gemm_impl(int do_verification,
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances
(
gemm_ptrs
);
}
}
}
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
...
@@ -204,6 +225,9 @@ void profile_gemm_impl(int do_verification,
...
@@ -204,6 +225,9 @@ void profile_gemm_impl(int do_verification,
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances
(
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances
(
gemm_ptrs
);
}
}
}
}
}
}
...
@@ -291,23 +315,65 @@ void profile_gemm_impl(int do_verification,
...
@@ -291,23 +315,65 @@ void profile_gemm_impl(int do_verification,
is_same
<
CDataType
,
ck
::
bhalf_t
>::
value
)
is_same
<
CDataType
,
ck
::
bhalf_t
>::
value
)
{
{
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
Column
Major
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
Row
Major
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances
(
gemm_ptrs
);
}
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances
(
gemm_ptrs
);
}
}
}
else
if
constexpr
(
is_same
<
ADataType
,
int8_t
>::
value
&&
is_same
<
BDataType
,
int8_t
>::
value
&&
else
if
constexpr
(
is_same
<
ADataType
,
int8_t
>::
value
&&
is_same
<
BDataType
,
int8_t
>::
value
&&
is_same
<
CDataType
,
int8_t
>::
value
)
is_same
<
CDataType
,
int8_t
>::
value
)
{
{
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
Column
Major
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
Row
Major
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances
(
gemm_ptrs
);
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances
(
gemm_ptrs
);
}
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances
(
gemm_ptrs
);
}
}
}
if
(
gemm_ptrs
.
size
()
<=
0
)
if
(
gemm_ptrs
.
size
()
<=
0
)
...
...
profiler/include/profile_gemm_reduce_impl.hpp
0 → 100644
View file @
e72c0c43
#pragma once
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_conv.hpp"
#include "tensor_layout.hpp"
#include "device_tensor.hpp"
#include "element_wise_operation.hpp"
#include "element_wise_reduce_operation.hpp"
#include "device_gemm_reduce.hpp"
#include "reference_gemm.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
DeviceGemmReduceNoOpPtr
=
ck
::
tensor_operation
::
device
::
DeviceGemmReducePtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
ReduceSum
,
ck
::
tensor_operation
::
element_wise
::
ReduceSquareSum
>
;
void
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmReduceNoOpPtr
>&
);
void
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmReduceNoOpPtr
>&
);
void
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmReduceNoOpPtr
>&
);
void
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmReduceNoOpPtr
>&
);
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
namespace
ck
{
namespace
profiler
{
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
DDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
bool
profile_gemm_reduce_impl
(
int
do_verification
,
int
init_method
,
bool
do_log
,
int
nrepeat
,
int
M
,
int
N
,
int
K
,
int
StrideA
,
int
StrideB
,
int
StrideC
)
{
bool
pass
=
true
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
is_same
<
decltype
(
layout
),
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
stride
,
1
}));
}
else
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
1
,
stride
}));
}
};
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
CDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
DDataType
>
d0_m_host_result
(
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
M
)})));
Tensor
<
DDataType
>
d1_m_host_result
(
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
M
)})));
Tensor
<
CDataType
>
c_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
DDataType
>
d0_m_device_result
(
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
M
)})));
Tensor
<
DDataType
>
d1_m_device_result
(
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
M
)})));
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n: "
<<
c_m_n_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d0_m: "
<<
d0_m_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d1_m: "
<<
d1_m_host_result
.
mDesc
<<
std
::
endl
;
std
::
size_t
num_thread
=
std
::
thread
::
hardware_concurrency
();
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
std
::
srand
(
0
);
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
},
num_thread
);
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
},
num_thread
);
break
;
default:
std
::
srand
(
0
);
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
},
num_thread
);
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
},
num_thread
);
}
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
D0ReduceOp
=
ck
::
tensor_operation
::
element_wise
::
ReduceSum
;
using
D1ReduceOp
=
ck
::
tensor_operation
::
element_wise
::
ReduceSquareSum
;
const
auto
a_element_op
=
AElementOp
{};
const
auto
b_element_op
=
BElementOp
{};
const
auto
c_element_op
=
CElementOp
{};
const
auto
d0_reduce_op
=
D0ReduceOp
{};
const
auto
d1_reduce_op
=
D1ReduceOp
{};
if
(
do_verification
)
{
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_m_k
,
b_k_n
,
c_m_n_host_result
,
a_element_op
,
b_element_op
,
c_element_op
);
ref_invoker
.
Run
(
ref_argument
);
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
float
d0_acc
=
d0_reduce_op
.
GetReduceZeroValue
();
float
d1_acc
=
d1_reduce_op
.
GetReduceZeroValue
();
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
d0_reduce_op
.
Reduce
(
d0_acc
,
c_m_n_host_result
(
m
,
n
));
d1_reduce_op
.
Reduce
(
d1_acc
,
c_m_n_host_result
(
m
,
n
));
}
d0_m_host_result
(
m
)
=
d0_acc
;
d1_m_host_result
(
m
)
=
d1_acc
;
}
}
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
d0_device_buf
(
sizeof
(
DDataType
)
*
d0_m_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
d1_device_buf
(
sizeof
(
DDataType
)
*
d1_m_device_result
.
mDesc
.
GetElementSpace
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
// add device GEMM instances
std
::
vector
<
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
DeviceGemmReduceNoOpPtr
>
gemm_ptrs
;
if
constexpr
(
is_same
<
ADataType
,
half_t
>::
value
&&
is_same
<
BDataType
,
half_t
>::
value
&&
is_same
<
CDataType
,
half_t
>::
value
)
{
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances
(
gemm_ptrs
);
}
}
if
(
gemm_ptrs
.
size
()
<=
0
)
{
throw
std
::
runtime_error
(
"wrong! no device GEMM instance found"
);
}
std
::
string
best_gemm_name
;
float
best_ave_time
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
// profile device GEMM instances
for
(
auto
&
gemm_ptr
:
gemm_ptrs
)
{
auto
argument_ptr
=
gemm_ptr
->
MakeArgumentPointer
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DDataType
*>
(
d0_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DDataType
*>
(
d1_device_buf
.
GetDeviceBuffer
()),
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
a_element_op
,
b_element_op
,
c_element_op
,
d0_reduce_op
,
d1_reduce_op
);
auto
invoker_ptr
=
gemm_ptr
->
MakeInvokerPointer
();
if
(
gemm_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
// warm up
invoker_ptr
->
Run
(
argument_ptr
.
get
());
// timing
float
total_time
=
0
;
for
(
int
i
=
0
;
i
<
nrepeat
;
++
i
)
{
// init DO, D1 to 0
d0_device_buf
.
SetZero
();
d1_device_buf
.
SetZero
();
KernelTimer
timer
;
timer
.
Start
();
invoker_ptr
->
Run
(
argument_ptr
.
get
());
timer
.
End
();
total_time
+=
timer
.
GetElapsedTime
();
}
float
ave_time
=
total_time
/
nrepeat
;
std
::
string
gemm_name
=
gemm_ptr
->
GetTypeString
();
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
M
+
sizeof
(
CDataType
)
*
M
*
N
+
sizeof
(
CDataType
)
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm_name
<<
std
::
endl
;
if
(
tflops
>
best_tflops
)
{
best_gemm_name
=
gemm_name
;
best_tflops
=
tflops
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
if
(
do_verification
)
{
c_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
d0_device_buf
.
FromDevice
(
d0_m_device_result
.
mData
.
data
());
d1_device_buf
.
FromDevice
(
d1_m_device_result
.
mData
.
data
());
float
c_error
=
check_error
(
c_m_n_host_result
,
c_m_n_device_result
);
float
d0_error
=
check_error
(
d0_m_host_result
,
d0_m_device_result
);
float
d1_error
=
check_error
(
d1_m_host_result
,
d1_m_device_result
);
pass
=
pass
&&
(
c_error
<
1E-6
);
pass
=
pass
&&
(
d0_error
<
1E-6
);
pass
=
pass
&&
(
d1_error
<
1E-6
);
if
(
do_log
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"a : "
,
a_m_k
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"b: "
,
b_k_n
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_host: "
,
c_m_n_host_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_device: "
,
c_m_n_device_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"d0_host: "
,
d0_m_host_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"d0_device: "
,
d0_m_device_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"d1_host: "
,
d1_m_host_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"d1_device: "
,
d1_m_device_result
.
mData
,
","
)
<<
std
::
endl
;
}
}
}
else
{
std
::
cout
<<
"does not support this GEMM problem"
<<
std
::
endl
;
}
}
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_gemm_name
<<
std
::
endl
;
return
pass
;
}
}
// namespace profiler
}
// namespace ck
profiler/include/profile_grouped_gemm_impl.hpp
0 → 100644
View file @
e72c0c43
#pragma once
#include <iomanip>
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_conv.hpp"
#include "tensor_layout.hpp"
#include "device_tensor.hpp"
#include "element_wise_operation.hpp"
#include "device_gemm.hpp"
#include "reference_gemm.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_grouped_gemm_instance
{
using
DeviceGroupedGemmNoOpPtr
=
ck
::
tensor_operation
::
device
::
DeviceGroupedGemmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
void
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGroupedGemmNoOpPtr
>&
);
void
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGroupedGemmNoOpPtr
>&
);
void
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGroupedGemmNoOpPtr
>&
);
void
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGroupedGemmNoOpPtr
>&
);
}
// namespace device_grouped_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
namespace
ck
{
namespace
profiler
{
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
void
profile_grouped_gemm_impl
(
int
do_verification
,
int
init_method
,
bool
do_log
,
int
nrepeat
,
std
::
vector
<
int
>
Ms
,
std
::
vector
<
int
>
Ns
,
std
::
vector
<
int
>
Ks
,
std
::
vector
<
int
>
StrideAs
,
std
::
vector
<
int
>
StrideBs
,
std
::
vector
<
int
>
StrideCs
)
{
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
is_same
<
decltype
(
layout
),
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
stride
,
1
}));
}
else
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
1
,
stride
}));
}
};
int
group_count
=
Ms
.
size
();
if
(
!
(
group_count
==
Ns
.
size
()
&&
group_count
==
Ks
.
size
()
&&
group_count
==
StrideAs
.
size
()
&&
group_count
==
StrideBs
.
size
()
&&
group_count
==
StrideCs
.
size
()))
{
throw
std
::
runtime_error
(
"wrong! inconsistent M/N/Ks, StrideA/B/Cs size
\n
"
);
}
std
::
vector
<
Tensor
<
ADataType
>>
a_m_k
;
std
::
vector
<
Tensor
<
BDataType
>>
b_k_n
;
std
::
vector
<
Tensor
<
CDataType
>>
c_m_n_device_results
;
for
(
int
i
=
0
;
i
<
Ms
.
size
();
i
++
)
{
a_m_k
.
push_back
(
Tensor
<
ADataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ks
[
i
],
StrideAs
[
i
],
ALayout
{})));
b_k_n
.
push_back
(
Tensor
<
BDataType
>
(
f_host_tensor_descriptor
(
Ks
[
i
],
Ns
[
i
],
StrideBs
[
i
],
BLayout
{})));
c_m_n_device_results
.
push_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{})));
std
::
cout
<<
"group: "
<<
i
<<
" a_m_k["
<<
i
<<
"]:"
<<
a_m_k
[
i
].
mDesc
<<
", b_k_n["
<<
i
<<
"]:"
<<
b_k_n
[
i
].
mDesc
<<
", c_m_n_device_results["
<<
i
<<
"]:"
<<
c_m_n_device_results
[
i
].
mDesc
<<
std
::
endl
;
std
::
size_t
num_thread
=
std
::
thread
::
hardware_concurrency
();
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
a_m_k
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
},
num_thread
);
b_k_n
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
},
num_thread
);
break
;
default:
a_m_k
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
},
num_thread
);
b_k_n
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
},
num_thread
);
}
c_m_n_device_results
[
i
].
GenerateTensorValue
(
GeneratorTensor_0
<
CDataType
>
{},
num_thread
);
}
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
const
auto
a_element_op
=
AElementOp
{};
const
auto
b_element_op
=
BElementOp
{};
const
auto
c_element_op
=
CElementOp
{};
// if(do_verification)
// {
// }
using
DeviceMemPtr
=
std
::
unique_ptr
<
DeviceMem
>
;
std
::
vector
<
DeviceMemPtr
>
a_device_buf
,
b_device_buf
,
c_device_buf
;
a_device_buf
.
reserve
(
group_count
);
b_device_buf
.
reserve
(
group_count
);
c_device_buf
.
reserve
(
group_count
);
std
::
vector
<
const
void
*>
p_a
,
p_b
;
std
::
vector
<
void
*>
p_c
;
p_a
.
reserve
(
group_count
);
p_b
.
reserve
(
group_count
);
p_c
.
reserve
(
group_count
);
std
::
vector
<
ck
::
tensor_operation
::
device
::
GemmShape
>
gemm_shapes
;
gemm_shapes
.
reserve
(
group_count
);
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
a_device_buf
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ADataType
)
*
a_m_k
[
i
].
mDesc
.
GetElementSpace
()));
b_device_buf
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
BDataType
)
*
b_k_n
[
i
].
mDesc
.
GetElementSpace
()));
c_device_buf
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
CDataType
)
*
c_m_n_device_results
[
i
].
mDesc
.
GetElementSpace
()));
a_device_buf
[
i
]
->
ToDevice
(
a_m_k
[
i
].
mData
.
data
());
b_device_buf
[
i
]
->
ToDevice
(
b_k_n
[
i
].
mData
.
data
());
c_device_buf
[
i
]
->
ToDevice
(
c_m_n_device_results
[
i
].
mData
.
data
());
gemm_shapes
.
push_back
({
Ms
[
i
],
Ns
[
i
],
Ks
[
i
],
StrideAs
[
i
],
StrideBs
[
i
],
StrideCs
[
i
]});
p_a
.
push_back
(
a_device_buf
[
i
]
->
GetDeviceBuffer
());
p_b
.
push_back
(
b_device_buf
[
i
]
->
GetDeviceBuffer
());
p_c
.
push_back
(
c_device_buf
[
i
]
->
GetDeviceBuffer
());
}
// add device GEMM instances
std
::
vector
<
ck
::
tensor_operation
::
device
::
device_grouped_gemm_instance
::
DeviceGroupedGemmNoOpPtr
>
gemm_ptrs
;
if
constexpr
(
is_same
<
ADataType
,
half_t
>::
value
&&
is_same
<
BDataType
,
half_t
>::
value
&&
is_same
<
CDataType
,
half_t
>::
value
)
{
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_grouped_gemm_instance
::
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_grouped_gemm_instance
::
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_grouped_gemm_instance
::
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_grouped_gemm_instance
::
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
gemm_ptrs
);
}
}
if
(
gemm_ptrs
.
size
()
<=
0
)
{
throw
std
::
runtime_error
(
"wrong! no device GEMM instance found"
);
}
std
::
string
best_gemm_name
;
float
best_ave_time
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
// profile device GEMM instances
for
(
auto
&
gemm_ptr
:
gemm_ptrs
)
{
auto
argument_ptr
=
gemm_ptr
->
MakeArgumentPointer
(
p_a
,
p_b
,
p_c
,
gemm_shapes
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
auto
invoker_ptr
=
gemm_ptr
->
MakeInvokerPointer
();
if
(
gemm_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
std
::
string
gemm_name
=
gemm_ptr
->
GetTypeString
();
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
nrepeat
);
std
::
size_t
flop
=
0
,
num_btype
=
0
;
for
(
int
i
=
0
;
i
<
gemm_shapes
.
size
();
i
++
)
{
flop
+=
std
::
size_t
(
2
)
*
Ms
[
i
]
*
Ns
[
i
]
*
Ks
[
i
];
num_btype
+=
sizeof
(
ADataType
)
*
Ms
[
i
]
*
Ks
[
i
]
+
sizeof
(
BDataType
)
*
Ks
[
i
]
*
Ns
[
i
]
+
sizeof
(
CDataType
)
*
Ms
[
i
]
*
Ns
[
i
];
}
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm_name
<<
std
::
endl
;
if
(
tflops
>
best_tflops
)
{
best_gemm_name
=
gemm_name
;
best_tflops
=
tflops
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
if
(
do_verification
)
{
for
(
int
i
=
0
;
i
<
gemm_shapes
.
size
();
i
++
)
{
c_device_buf
[
i
]
->
FromDevice
(
c_m_n_device_results
[
i
].
mData
.
data
());
Tensor
<
CDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{}));
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_m_k
[
i
],
b_k_n
[
i
],
c_m_n_host_result
,
a_element_op
,
b_element_op
,
c_element_op
);
ref_invoker
.
Run
(
ref_argument
);
check_error
(
c_m_n_host_result
,
c_m_n_device_results
[
i
]);
if
(
do_log
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"a : "
,
a_m_k
[
i
].
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"b: "
,
b_k_n
[
i
].
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_device: "
,
c_m_n_device_results
[
i
].
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_host : "
,
c_m_n_host_result
.
mData
,
","
)
<<
std
::
endl
;
}
}
}
}
else
{
std
::
cout
<<
"does not support this GEMM problem"
<<
std
::
endl
;
}
}
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_gemm_name
<<
std
::
endl
;
}
// namespace profiler
}
// namespace profiler
}
// namespace ck
profiler/include/profile_reduce_impl.hpp
View file @
e72c0c43
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
#include "device_reduce.hpp"
#include "device_reduce.hpp"
#include "device_reduce_instance.hpp"
#include "device_reduce_instance.hpp"
#include "reduction_enums.hpp"
#include "reduction_enums.hpp"
#include "host_
generic_
reduction.hpp"
#include "host_reduction.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -20,34 +20,43 @@ struct ReduceDescription
...
@@ -20,34 +20,43 @@ struct ReduceDescription
};
};
using
reduce_description_instances
=
std
::
tuple
<
ReduceDescription
<
4
,
3
,
0
,
0
,
0
>
,
// for ADD
using
reduce_description_instances
=
std
::
tuple
<
ReduceDescription
<
4
,
3
,
0
,
0
,
0
>
,
// for ADD
ReduceDescription
<
4
,
4
,
0
,
0
,
0
>
,
ReduceDescription
<
4
,
1
,
0
,
0
,
0
>
,
ReduceDescription
<
4
,
1
,
0
,
0
,
0
>
,
ReduceDescription
<
2
,
1
,
0
,
0
,
0
>
,
ReduceDescription
<
2
,
1
,
0
,
0
,
0
>
,
ReduceDescription
<
4
,
3
,
5
,
0
,
0
>
,
// for AVG
ReduceDescription
<
4
,
3
,
5
,
0
,
0
>
,
// for AVG
ReduceDescription
<
4
,
4
,
5
,
0
,
0
>
,
ReduceDescription
<
4
,
1
,
5
,
0
,
0
>
,
ReduceDescription
<
4
,
1
,
5
,
0
,
0
>
,
ReduceDescription
<
2
,
1
,
5
,
0
,
0
>
,
ReduceDescription
<
2
,
1
,
5
,
0
,
0
>
,
ReduceDescription
<
4
,
3
,
7
,
0
,
0
>
,
// for NORM2
ReduceDescription
<
4
,
3
,
7
,
0
,
0
>
,
// for NORM2
ReduceDescription
<
4
,
4
,
7
,
0
,
0
>
,
ReduceDescription
<
4
,
1
,
7
,
0
,
0
>
,
ReduceDescription
<
4
,
1
,
7
,
0
,
0
>
,
ReduceDescription
<
2
,
1
,
7
,
0
,
0
>
,
ReduceDescription
<
2
,
1
,
7
,
0
,
0
>
,
ReduceDescription
<
4
,
3
,
2
,
0
,
0
>
,
// for MIN
ReduceDescription
<
4
,
3
,
2
,
0
,
0
>
,
// for MIN
ReduceDescription
<
4
,
4
,
2
,
0
,
0
>
,
ReduceDescription
<
4
,
1
,
2
,
0
,
0
>
,
ReduceDescription
<
4
,
1
,
2
,
0
,
0
>
,
ReduceDescription
<
2
,
1
,
2
,
0
,
0
>
,
ReduceDescription
<
2
,
1
,
2
,
0
,
0
>
,
ReduceDescription
<
4
,
3
,
3
,
0
,
0
>
,
// for MAX
ReduceDescription
<
4
,
3
,
3
,
0
,
0
>
,
// for MAX
ReduceDescription
<
4
,
4
,
3
,
0
,
0
>
,
ReduceDescription
<
4
,
1
,
3
,
0
,
0
>
,
ReduceDescription
<
4
,
1
,
3
,
0
,
0
>
,
ReduceDescription
<
2
,
1
,
3
,
0
,
0
>
,
ReduceDescription
<
2
,
1
,
3
,
0
,
0
>
,
ReduceDescription
<
4
,
3
,
4
,
0
,
0
>
,
// for AMAX
ReduceDescription
<
4
,
3
,
4
,
0
,
0
>
,
// for AMAX
ReduceDescription
<
4
,
4
,
4
,
0
,
0
>
,
ReduceDescription
<
4
,
1
,
4
,
0
,
0
>
,
ReduceDescription
<
4
,
1
,
4
,
0
,
0
>
,
ReduceDescription
<
2
,
1
,
4
,
0
,
0
>
,
ReduceDescription
<
2
,
1
,
4
,
0
,
0
>
,
ReduceDescription
<
4
,
3
,
2
,
0
,
1
>
,
// for MIN
ReduceDescription
<
4
,
3
,
2
,
0
,
1
>
,
// for MIN
ReduceDescription
<
4
,
4
,
2
,
0
,
1
>
,
ReduceDescription
<
4
,
1
,
2
,
0
,
1
>
,
ReduceDescription
<
4
,
1
,
2
,
0
,
1
>
,
ReduceDescription
<
2
,
1
,
2
,
0
,
1
>
,
ReduceDescription
<
2
,
1
,
2
,
0
,
1
>
,
ReduceDescription
<
4
,
3
,
3
,
0
,
1
>
,
// for MAX
ReduceDescription
<
4
,
3
,
3
,
0
,
1
>
,
// for MAX
ReduceDescription
<
4
,
4
,
3
,
0
,
1
>
,
ReduceDescription
<
4
,
1
,
3
,
0
,
1
>
,
ReduceDescription
<
4
,
1
,
3
,
0
,
1
>
,
ReduceDescription
<
2
,
1
,
3
,
0
,
1
>
,
ReduceDescription
<
2
,
1
,
3
,
0
,
1
>
,
ReduceDescription
<
4
,
3
,
4
,
0
,
1
>
,
// for AMAX
ReduceDescription
<
4
,
3
,
4
,
0
,
1
>
,
// for AMAX
ReduceDescription
<
4
,
4
,
4
,
0
,
1
>
,
ReduceDescription
<
4
,
1
,
4
,
0
,
1
>
,
ReduceDescription
<
4
,
1
,
4
,
0
,
1
>
,
ReduceDescription
<
2
,
1
,
4
,
0
,
1
>>
;
ReduceDescription
<
2
,
1
,
4
,
0
,
1
>>
;
...
@@ -122,16 +131,16 @@ static void dumpBufferToFile(const char* fileName, T* data, size_t dataNumItems)
...
@@ -122,16 +131,16 @@ static void dumpBufferToFile(const char* fileName, T* data, size_t dataNumItems)
};
};
// map the data type used by the GPU kernels to the corresponding type used by the host codes
// map the data type used by the GPU kernels to the corresponding type used by the host codes
template
<
typename
inData
Type
>
template
<
typename
In
Type
>
struct
type_mapping
struct
type_mapping
{
{
using
o
ut
Data
Type
=
inData
Type
;
using
O
utType
=
In
Type
;
};
};
template
<
>
template
<
>
struct
type_mapping
<
ck
::
half_t
>
struct
type_mapping
<
ck
::
half_t
>
{
{
using
o
ut
Data
Type
=
half_float
::
half
;
using
O
utType
=
half_float
::
half
;
};
};
template
<
typename
InDataType
,
template
<
typename
InDataType
,
...
@@ -187,7 +196,26 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -187,7 +196,26 @@ void profile_reduce_impl_impl(bool do_verification,
constexpr
bool
invalid_reduce_3
=
constexpr
bool
invalid_reduce_3
=
(
!
op_support_indices
&&
IndicesOpt
!=
ReduceTensorIndices_t
::
NO_INDICES
);
(
!
op_support_indices
&&
IndicesOpt
!=
ReduceTensorIndices_t
::
NO_INDICES
);
constexpr
bool
invalid_reduce
=
(
invalid_reduce_1
||
invalid_reduce_2
||
invalid_reduce_3
);
// 1) If InDataType is int8_t, must use int8_t as AccDataType for indexable reduction operations
// 2) If InDataType is int8_t, must use int32_t as AccDataType for non-indexable reduction
// operations
constexpr
bool
invalid_reduce_4
=
std
::
is_same
<
InDataType
,
int8_t
>::
value
&&
((
!
op_support_indices
&&
!
std
::
is_same
<
AccDataType
,
int32_t
>::
value
)
||
(
op_support_indices
&&
!
std
::
is_same
<
AccDataType
,
int8_t
>::
value
));
// 1) If InDataType is int8_t, the supported operation must be either indexable operations or
// ADD/AVG
constexpr
bool
invalid_reduce_5
=
std
::
is_same
<
InDataType
,
int8_t
>::
value
&&
(
!
op_support_indices
&&
ReduceOpId
!=
ReduceTensorOp_t
::
ADD
&&
ReduceOpId
!=
ReduceTensorOp_t
::
AVG
);
// 1) If InDataType is bhalf_t, must use float as AccDataType for all reduction operations
constexpr
bool
invalid_reduce_6
=
std
::
is_same
<
InDataType
,
bhalf_t
>::
value
&&
!
std
::
is_same
<
AccDataType
,
float
>::
value
;
constexpr
bool
invalid_reduce
=
(
invalid_reduce_1
||
invalid_reduce_2
||
invalid_reduce_3
||
invalid_reduce_4
||
invalid_reduce_5
||
invalid_reduce_6
);
if
constexpr
(
!
invalid_reduce
)
if
constexpr
(
!
invalid_reduce
)
{
{
...
@@ -205,8 +233,8 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -205,8 +233,8 @@ void profile_reduce_impl_impl(bool do_verification,
Tensor
<
OutDataType
>
out_ref
(
outLengths
);
Tensor
<
OutDataType
>
out_ref
(
outLengths
);
Tensor
<
OutDataType
>
out
(
outLengths
);
Tensor
<
OutDataType
>
out
(
outLengths
);
Tensor
<
int
>
out_indices_ref
(
outLengths
);
Tensor
<
int
32_t
>
out_indices_ref
(
outLengths
);
Tensor
<
int
>
out_indices
(
outLengths
);
Tensor
<
int
32_t
>
out_indices
(
outLengths
);
auto
inStrides
=
in
.
mDesc
.
GetStrides
();
auto
inStrides
=
in
.
mDesc
.
GetStrides
();
auto
outStrides
=
out
.
mDesc
.
GetStrides
();
auto
outStrides
=
out
.
mDesc
.
GetStrides
();
...
@@ -220,20 +248,22 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -220,20 +248,22 @@ void profile_reduce_impl_impl(bool do_verification,
{
{
switch
(
init_method
)
switch
(
init_method
)
{
{
case
0
:
case
0
:
break
;
in
.
GenerateTensorValue
(
GeneratorTensor_1
<
InDataType
>
{},
num_thread
);
case
1
:
in
.
GenerateTensorValue
(
GeneratorTensor_1
<
InDataType
>
{
1
},
num_thread
);
if
(
beta
!=
0.0
f
)
if
(
beta
!=
0.0
f
)
out_ref
.
GenerateTensorValue
(
GeneratorTensor_1
<
InDataType
>
{},
num_thread
);
out_ref
.
GenerateTensorValue
(
GeneratorTensor_1
<
InDataType
>
{
1
},
num_thread
);
break
;
break
;
case
1
:
case
2
:
in
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
5
,
5
},
num_thread
);
in
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
5
,
5
},
num_thread
);
if
(
beta
!=
0.0
f
)
if
(
beta
!=
0.0
f
)
out_ref
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
5
,
5
},
num_thread
);
out_ref
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
5
,
5
},
num_thread
);
break
;
break
;
default:
default:
in
.
GenerateTensorValue
(
GeneratorTensor_
2
<
InDataType
>
{
1
,
5
},
num_thread
);
in
.
GenerateTensorValue
(
GeneratorTensor_
3
<
InDataType
>
{
-
5.0
,
5
.0
},
num_thread
);
if
(
beta
!=
0.0
f
)
if
(
beta
!=
0.0
f
)
out_ref
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
1
,
5
},
num_thread
);
out_ref
.
GenerateTensorValue
(
GeneratorTensor_3
<
InDataType
>
{
-
5.0
,
5.0
},
num_thread
);
}
}
if
(
beta
!=
0.0
f
)
if
(
beta
!=
0.0
f
)
...
@@ -306,6 +336,7 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -306,6 +336,7 @@ void profile_reduce_impl_impl(bool do_verification,
IndicesOpt
>
(
reduce0_ptrs
);
IndicesOpt
>
(
reduce0_ptrs
);
if
constexpr
(
use_atomic_add
)
if
constexpr
(
use_atomic_add
)
{
add_device_reduce_instance_multiblock_atomic_add
<
InDataType
,
add_device_reduce_instance_multiblock_atomic_add
<
InDataType
,
AccDataType
,
AccDataType
,
OutDataType
,
OutDataType
,
...
@@ -314,7 +345,9 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -314,7 +345,9 @@ void profile_reduce_impl_impl(bool do_verification,
ReduceOpId
,
ReduceOpId
,
NanOpt
,
NanOpt
,
IndicesOpt
>
(
reduce0_ptrs
);
IndicesOpt
>
(
reduce0_ptrs
);
}
else
else
{
add_device_reduce_instance_multiblock_partial_reduce
<
InDataType
,
add_device_reduce_instance_multiblock_partial_reduce
<
InDataType
,
AccDataType
,
AccDataType
,
OutDataType
,
OutDataType
,
...
@@ -323,9 +356,11 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -323,9 +356,11 @@ void profile_reduce_impl_impl(bool do_verification,
ReduceOpId
,
ReduceOpId
,
NanOpt
,
NanOpt
,
IndicesOpt
>
(
reduce1_ptrs
);
IndicesOpt
>
(
reduce1_ptrs
);
};
// used for secondary reduction
// used for secondary reduction
if
constexpr
(
!
use_atomic_add
)
if
constexpr
(
!
use_atomic_add
)
{
add_device_reduce_instance_blockwise_second_call
<
AccDataType
,
add_device_reduce_instance_blockwise_second_call
<
AccDataType
,
AccDataType
,
AccDataType
,
OutDataType
,
OutDataType
,
...
@@ -334,6 +369,7 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -334,6 +369,7 @@ void profile_reduce_impl_impl(bool do_verification,
ReduceOpId
,
ReduceOpId
,
NanOpt
,
NanOpt
,
IndicesOpt
>
(
reduce2_ptrs
);
IndicesOpt
>
(
reduce2_ptrs
);
};
if
(
reduce0_ptrs
.
empty
()
&&
reduce1_ptrs
.
empty
())
if
(
reduce0_ptrs
.
empty
()
&&
reduce1_ptrs
.
empty
())
{
{
...
@@ -342,17 +378,24 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -342,17 +378,24 @@ void profile_reduce_impl_impl(bool do_verification,
if
(
do_verification
)
if
(
do_verification
)
{
{
using
hInType
=
typename
type_mapping
<
InDataType
>::
outDataType
;
using
HostInDataType
=
typename
type_mapping
<
InDataType
>::
OutType
;
using
hOutType
=
typename
type_mapping
<
OutDataType
>::
outDataType
;
using
HostOutDataType
=
typename
type_mapping
<
OutDataType
>::
OutType
;
using
hCompType
=
typename
type_mapping
<
AccDataType
>::
outDataType
;
using
HostAccDataType
=
typename
type_mapping
<
AccDataType
>::
OutType
;
ReductionHost
<
hInType
,
hCompType
,
hOutType
,
ReduceOpId
,
PropagateNan
,
NeedIndices
>
ReductionHost
<
HostInDataType
,
HostAccDataType
,
HostOutDataType
,
ReduceOpId
,
Rank
,
NumReduceDim
,
PropagateNan
,
NeedIndices
>
hostReduce
(
in
.
mDesc
,
out_ref
.
mDesc
,
invariantDims
,
reduceDims
);
hostReduce
(
in
.
mDesc
,
out_ref
.
mDesc
,
invariantDims
,
reduceDims
);
hostReduce
.
Run
(
alpha
,
hostReduce
.
Run
(
alpha
,
reinterpret_cast
<
const
hIn
Type
*>
(
in
.
mData
.
data
()),
reinterpret_cast
<
const
HostInData
Type
*>
(
in
.
mData
.
data
()),
beta
,
beta
,
reinterpret_cast
<
hOut
Type
*>
(
out_ref
.
mData
.
data
()),
reinterpret_cast
<
HostOutData
Type
*>
(
out_ref
.
mData
.
data
()),
out_indices_ref
.
mData
.
data
());
out_indices_ref
.
mData
.
data
());
};
};
...
@@ -363,24 +406,27 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -363,24 +406,27 @@ void profile_reduce_impl_impl(bool do_verification,
for
(
auto
&
reduce_ptr
:
reduce0_ptrs
)
for
(
auto
&
reduce_ptr
:
reduce0_ptrs
)
{
{
auto
wsSizeInBytes
=
reduce_ptr
->
GetWorkspaceSizeInBytes
(
i_inLengths
);
auto
wsSizeInBytes
=
reduce_ptr
->
GetWorkspaceSizeInBytes
(
i_inLengths
,
reduceDims
);
DeviceMem
ws_dev
(
wsSizeInBytes
);
DeviceMem
ws_dev
(
wsSizeInBytes
);
auto
argument_ptr
=
reduce_ptr
->
MakeArgumentPointer
(
InElementwiseOperation_0
in_elementwise_op_0
(
static_cast
<
int32_t
>
(
reduce_total_length
));
i_inLengths
,
AccElementwiseOperation_0
acc_elementwise_op_0
(
i_inStrides
,
static_cast
<
int32_t
>
(
reduce_total_length
));
i_outLengths
,
i_outStrides
,
auto
argument_ptr
=
reduce_ptr
->
MakeArgumentPointer
(
i_inLengths
,
reduceDims
,
i_inStrides
,
alpha
,
i_outLengths
,
beta
,
i_outStrides
,
in_dev
.
GetDeviceBuffer
(),
reduceDims
,
out_dev
.
GetDeviceBuffer
(),
alpha
,
out_indices_dev
.
GetDeviceBuffer
(),
beta
,
ws_dev
.
GetDeviceBuffer
(),
in_dev
.
GetDeviceBuffer
(),
InElementwiseOperation_0
{
static_cast
<
int32_t
>
(
reduce_total_length
)},
out_dev
.
GetDeviceBuffer
(),
AccElementwiseOperation_0
{
static_cast
<
int32_t
>
(
reduce_total_length
)});
out_indices_dev
.
GetDeviceBuffer
(),
ws_dev
.
GetDeviceBuffer
(),
in_elementwise_op_0
,
acc_elementwise_op_0
);
if
(
!
reduce_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
if
(
!
reduce_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
continue
;
continue
;
...
@@ -445,24 +491,27 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -445,24 +491,27 @@ void profile_reduce_impl_impl(bool do_verification,
for
(
auto
&
reduce_ptr
:
reduce1_ptrs
)
for
(
auto
&
reduce_ptr
:
reduce1_ptrs
)
{
{
auto
wsSizeInBytes
=
reduce_ptr
->
GetWorkspaceSizeInBytes
(
i_inLengths
);
auto
wsSizeInBytes
=
reduce_ptr
->
GetWorkspaceSizeInBytes
(
i_inLengths
,
reduceDims
);
DeviceMem
ws_dev
(
wsSizeInBytes
);
DeviceMem
ws_dev
(
wsSizeInBytes
);
auto
argument_ptr
=
reduce_ptr
->
MakeArgumentPointer
(
InElementwiseOperation_1
in_elementwise_op_1
(
static_cast
<
int32_t
>
(
reduce_total_length
));
i_inLengths
,
AccElementwiseOperation_1
acc_elementwise_op_1
(
i_inStrides
,
static_cast
<
int32_t
>
(
reduce_total_length
));
i_outLengths
,
i_outStrides
,
auto
argument_ptr
=
reduce_ptr
->
MakeArgumentPointer
(
i_inLengths
,
reduceDims
,
i_inStrides
,
alpha
,
i_outLengths
,
beta
,
i_outStrides
,
in_dev
.
GetDeviceBuffer
(),
reduceDims
,
out_dev
.
GetDeviceBuffer
(),
alpha
,
out_indices_dev
.
GetDeviceBuffer
(),
beta
,
ws_dev
.
GetDeviceBuffer
(),
in_dev
.
GetDeviceBuffer
(),
InElementwiseOperation_1
{
static_cast
<
int32_t
>
(
reduce_total_length
)},
out_dev
.
GetDeviceBuffer
(),
AccElementwiseOperation_1
{
static_cast
<
int32_t
>
(
reduce_total_length
)});
out_indices_dev
.
GetDeviceBuffer
(),
ws_dev
.
GetDeviceBuffer
(),
in_elementwise_op_1
,
acc_elementwise_op_1
);
if
(
!
reduce_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
if
(
!
reduce_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
continue
;
continue
;
...
@@ -482,20 +531,25 @@ void profile_reduce_impl_impl(bool do_verification,
...
@@ -482,20 +531,25 @@ void profile_reduce_impl_impl(bool do_verification,
for
(
auto
&
reduce2_ptr
:
reduce2_ptrs
)
for
(
auto
&
reduce2_ptr
:
reduce2_ptrs
)
{
{
auto
argument2_ptr
=
reduce2_ptr
->
MakeArgumentPointer
(
InElementwiseOperation_2
in_elementwise_op_2
(
inLengths2
,
static_cast
<
int32_t
>
(
reduce_total_length
));
inStrides2
,
AccElementwiseOperation_2
acc_elementwise_op_2
(
i_outLengths
,
static_cast
<
int32_t
>
(
reduce_total_length
));
i_outStrides
,
reduceDims
,
auto
argument2_ptr
=
alpha
,
reduce2_ptr
->
MakeArgumentPointer
(
inLengths2
,
beta
,
inStrides2
,
ws_dev
.
GetDeviceBuffer
(),
i_outLengths
,
out_dev
.
GetDeviceBuffer
(),
i_outStrides
,
out_indices_dev
.
GetDeviceBuffer
(),
reduceDims
,
ws_dev
.
GetDeviceBuffer
(),
alpha
,
InElementwiseOperation_2
{
static_cast
<
int32_t
>
(
reduce_total_length
)},
beta
,
AccElementwiseOperation_2
{
static_cast
<
int32_t
>
(
reduce_total_length
)});
ws_dev
.
GetDeviceBuffer
(),
out_dev
.
GetDeviceBuffer
(),
out_indices_dev
.
GetDeviceBuffer
(),
ws_dev
.
GetDeviceBuffer
(),
in_elementwise_op_2
,
acc_elementwise_op_2
);
if
(
!
reduce2_ptr
->
IsSupportedArgument
(
argument2_ptr
.
get
()))
if
(
!
reduce2_ptr
->
IsSupportedArgument
(
argument2_ptr
.
get
()))
continue
;
continue
;
...
...
profiler/src/README.md
View file @
e72c0c43
...
@@ -67,8 +67,8 @@ Best Perf: 1.1933 ms, 107.977 TFlops, 79.0848 GB/s
...
@@ -67,8 +67,8 @@ Best Perf: 1.1933 ms, 107.977 TFlops, 79.0848 GB/s
#arg8: print matrix value (0=no, 1=yes)
#arg8: print matrix value (0=no, 1=yes)
#arg9: run kernel # of times (>1)
#arg9: run kernel # of times (>1)
#arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx
#arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx
##################### op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads
##################### op
datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads
./profiler/ckProfiler conv
1 1 1 1 1 1 0 5 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
./profiler/ckProfiler conv
_fwd
1 1 1 1 1 1 0 5 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
```
```
Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
...
...
profiler/src/profile_batched_gemm.cpp
View file @
e72c0c43
#include <cstdint>
#include <iostream>
#include <iostream>
#include <numeric>
#include <numeric>
#include <initializer_list>
#include <initializer_list>
...
@@ -15,7 +16,7 @@
...
@@ -15,7 +16,7 @@
#include "device_batched_gemm_xdl.hpp"
#include "device_batched_gemm_xdl.hpp"
#include "profile_batched_gemm_impl.hpp"
#include "profile_batched_gemm_impl.hpp"
enum
GemmMatrixLayout
enum
struct
GemmMatrixLayout
{
{
MK_KN_MN
,
// 0
MK_KN_MN
,
// 0
MK_NK_MN
,
// 1
MK_NK_MN
,
// 1
...
@@ -27,10 +28,12 @@ enum GemmMatrixLayout
...
@@ -27,10 +28,12 @@ enum GemmMatrixLayout
KM_NK_NM
,
// 7
KM_NK_NM
,
// 7
};
};
enum
GemmDataType
enum
struct
GemmDataType
{
{
F32_F32_F32
,
// 0
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
F16_F16_F16
,
// 1
BF16_BF16_BF16
,
// 2
INT8_INT8_INT8
,
// 3
};
};
int
profile_batched_gemm
(
int
argc
,
char
*
argv
[])
int
profile_batched_gemm
(
int
argc
,
char
*
argv
[])
...
@@ -38,7 +41,7 @@ int profile_batched_gemm(int argc, char* argv[])
...
@@ -38,7 +41,7 @@ int profile_batched_gemm(int argc, char* argv[])
if
(
!
(
argc
==
15
))
if
(
!
(
argc
==
15
))
{
{
printf
(
"arg1: tensor operation (batched_gemm: Batched GEMM)
\n
"
);
printf
(
"arg1: tensor operation (batched_gemm: Batched GEMM)
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16)
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16
, 2: bf16, 3: int8
)
\n
"
);
printf
(
"arg3: matrix layout (0: A[g, m, k] * B[g, k, n] = C[g, m, n];
\n
"
);
printf
(
"arg3: matrix layout (0: A[g, m, k] * B[g, k, n] = C[g, m, n];
\n
"
);
printf
(
" 1: A[g, m, k] * B[g, n, k] = C[g, m, n];
\n
"
);
printf
(
" 1: A[g, m, k] * B[g, n, k] = C[g, m, n];
\n
"
);
printf
(
" 2: A[g, k, m] * B[g, k, n] = C[g, m, n];
\n
"
);
printf
(
" 2: A[g, k, m] * B[g, k, n] = C[g, m, n];
\n
"
);
...
@@ -51,8 +54,8 @@ int profile_batched_gemm(int argc, char* argv[])
...
@@ -51,8 +54,8 @@ int profile_batched_gemm(int argc, char* argv[])
exit
(
1
);
exit
(
1
);
}
}
const
int
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
int
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
...
@@ -146,6 +149,241 @@ int profile_batched_gemm(int argc, char* argv[])
...
@@ -146,6 +149,241 @@ int profile_batched_gemm(int argc, char* argv[])
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
(
StrideC
<
0
)
?
N
:
StrideC
);
}
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
ck
::
profiler
::
profile_batched_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
BatchCount
);
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
ck
::
profiler
::
profile_batched_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
BatchCount
);
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
ck
::
profiler
::
profile_batched_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
ck
::
profiler
::
profile_batched_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
ck
::
profiler
::
profile_batched_gemm_impl
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
BatchCount
);
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
ck
::
profiler
::
profile_batched_gemm_impl
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
BatchCount
);
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
ck
::
profiler
::
profile_batched_gemm_impl
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
ck
::
profiler
::
profile_batched_gemm_impl
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
}
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
ck
::
profiler
::
profile_batched_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
BatchCount
);
}
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
ck
::
profiler
::
profile_batched_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
BatchCount
);
}
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
ck
::
profiler
::
profile_batched_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
}
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
ck
::
profiler
::
profile_batched_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
BatchCount
);
}
else
else
{
{
throw
std
::
runtime_error
(
"wrong! this GEMM data_type & layout is not implemented"
);
throw
std
::
runtime_error
(
"wrong! this GEMM data_type & layout is not implemented"
);
...
...
profiler/src/profile_conv_bwd_data.cpp
View file @
e72c0c43
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#include <half.hpp>
#include <half.hpp>
#include "profile_conv_bwd_data_impl.hpp"
#include "profile_conv_bwd_data_impl.hpp"
enum
ConvDataType
enum
struct
ConvDataType
{
{
F32_F32_F32
,
// 0
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
F16_F16_F16
,
// 1
...
@@ -14,19 +14,19 @@ enum ConvDataType
...
@@ -14,19 +14,19 @@ enum ConvDataType
INT8_INT8_INT8
,
// 3
INT8_INT8_INT8
,
// 3
};
};
enum
ConvInputLayout
enum
struct
ConvInputLayout
{
{
NCHW
,
// 0
NCHW
,
// 0
NHWC
,
// 1
NHWC
,
// 1
};
};
enum
ConvWeightLayout
enum
struct
ConvWeightLayout
{
{
KCYX
,
// 0
KCYX
,
// 0
KYXC
,
// 1
KYXC
,
// 1
};
};
enum
ConvOutputLayout
enum
struct
ConvOutputLayout
{
{
NKHW
,
// 0
NKHW
,
// 0
NHWK
,
// 1
NHWK
,
// 1
...
@@ -50,10 +50,10 @@ int profile_conv_bwd_data(int argc, char* argv[])
...
@@ -50,10 +50,10 @@ int profile_conv_bwd_data(int argc, char* argv[])
exit
(
1
);
exit
(
1
);
}
}
const
int
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
int
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
int
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
auto
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
int
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
auto
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
...
...
profiler/src/profile_conv_fwd.cpp
View file @
e72c0c43
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#include <half.hpp>
#include <half.hpp>
#include "profile_conv_fwd_impl.hpp"
#include "profile_conv_fwd_impl.hpp"
enum
ConvDataType
enum
struct
ConvDataType
{
{
F32_F32_F32
,
// 0
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
F16_F16_F16
,
// 1
...
@@ -14,19 +14,19 @@ enum ConvDataType
...
@@ -14,19 +14,19 @@ enum ConvDataType
INT8_INT8_INT8
,
// 3
INT8_INT8_INT8
,
// 3
};
};
enum
ConvInputLayout
enum
struct
ConvInputLayout
{
{
NCHW
,
// 0
NCHW
,
// 0
NHWC
,
// 1
NHWC
,
// 1
};
};
enum
ConvWeightLayout
enum
struct
ConvWeightLayout
{
{
KCYX
,
// 0
KCYX
,
// 0
KYXC
,
// 1
KYXC
,
// 1
};
};
enum
ConvOutputLayout
enum
struct
ConvOutputLayout
{
{
NKHW
,
// 0
NKHW
,
// 0
NHWK
,
// 1
NHWK
,
// 1
...
@@ -50,10 +50,10 @@ int profile_conv_fwd(int argc, char* argv[])
...
@@ -50,10 +50,10 @@ int profile_conv_fwd(int argc, char* argv[])
exit
(
1
);
exit
(
1
);
}
}
const
int
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
int
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
int
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
auto
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
int
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
auto
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
...
...
profiler/src/profile_conv_fwd_bias_relu.cpp
View file @
e72c0c43
...
@@ -6,25 +6,25 @@
...
@@ -6,25 +6,25 @@
#include <half.hpp>
#include <half.hpp>
#include "profile_conv_fwd_bias_relu_impl.hpp"
#include "profile_conv_fwd_bias_relu_impl.hpp"
enum
ConvDataType
enum
struct
ConvDataType
{
{
F32_F32_F32
,
// 0
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
F16_F16_F16
,
// 1
};
};
enum
ConvInputLayout
enum
struct
ConvInputLayout
{
{
NCHW
,
// 0
NCHW
,
// 0
NHWC
,
// 1
NHWC
,
// 1
};
};
enum
ConvWeightLayout
enum
struct
ConvWeightLayout
{
{
KCYX
,
// 0
KCYX
,
// 0
KYXC
,
// 1
KYXC
,
// 1
};
};
enum
ConvOutputLayout
enum
struct
ConvOutputLayout
{
{
NKHW
,
// 0
NKHW
,
// 0
NHWK
,
// 1
NHWK
,
// 1
...
@@ -48,10 +48,10 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[])
...
@@ -48,10 +48,10 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[])
exit
(
1
);
exit
(
1
);
}
}
const
int
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
int
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
int
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
auto
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
int
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
auto
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
...
...
profiler/src/profile_conv_fwd_bias_relu_add.cpp
View file @
e72c0c43
...
@@ -6,25 +6,25 @@
...
@@ -6,25 +6,25 @@
#include <half.hpp>
#include <half.hpp>
#include "profile_conv_fwd_bias_relu_add_impl.hpp"
#include "profile_conv_fwd_bias_relu_add_impl.hpp"
enum
ConvDataType
enum
struct
ConvDataType
{
{
F32_F32_F32
,
// 0
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
F16_F16_F16
,
// 1
};
};
enum
ConvInputLayout
enum
struct
ConvInputLayout
{
{
NCHW
,
// 0
NCHW
,
// 0
NHWC
,
// 1
NHWC
,
// 1
};
};
enum
ConvWeightLayout
enum
struct
ConvWeightLayout
{
{
KCYX
,
// 0
KCYX
,
// 0
KYXC
,
// 1
KYXC
,
// 1
};
};
enum
ConvOutputLayout
enum
struct
ConvOutputLayout
{
{
NKHW
,
// 0
NKHW
,
// 0
NHWK
,
// 1
NHWK
,
// 1
...
@@ -49,10 +49,10 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[])
...
@@ -49,10 +49,10 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[])
exit
(
1
);
exit
(
1
);
}
}
const
int
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
int
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
int
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
auto
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
int
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
auto
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
...
...
profiler/src/profile_conv_fwd_bias_relu_atomic_add.cpp
View file @
e72c0c43
...
@@ -6,25 +6,25 @@
...
@@ -6,25 +6,25 @@
#include <half.hpp>
#include <half.hpp>
#include "profile_conv_fwd_bias_relu_atomic_add_impl.hpp"
#include "profile_conv_fwd_bias_relu_atomic_add_impl.hpp"
enum
ConvDataType
enum
struct
ConvDataType
{
{
F32_F32_F32
,
// 0
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
F16_F16_F16
,
// 1
};
};
enum
ConvInputLayout
enum
struct
ConvInputLayout
{
{
NCHW
,
// 0
NCHW
,
// 0
NHWC
,
// 1
NHWC
,
// 1
};
};
enum
ConvWeightLayout
enum
struct
ConvWeightLayout
{
{
KCYX
,
// 0
KCYX
,
// 0
KYXC
,
// 1
KYXC
,
// 1
};
};
enum
ConvOutputLayout
enum
struct
ConvOutputLayout
{
{
NKHW
,
// 0
NKHW
,
// 0
NHWK
,
// 1
NHWK
,
// 1
...
@@ -49,10 +49,10 @@ int profile_conv_fwd_bias_relu_atomic_add(int argc, char* argv[])
...
@@ -49,10 +49,10 @@ int profile_conv_fwd_bias_relu_atomic_add(int argc, char* argv[])
exit
(
1
);
exit
(
1
);
}
}
const
int
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
int
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
int
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
auto
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
int
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
auto
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
...
...
profiler/src/profile_gemm.cpp
View file @
e72c0c43
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#include <half.hpp>
#include <half.hpp>
#include "profile_gemm_impl.hpp"
#include "profile_gemm_impl.hpp"
enum
GemmMatrixLayout
enum
struct
GemmMatrixLayout
{
{
MK_KN_MN
,
// 0
MK_KN_MN
,
// 0
MK_NK_MN
,
// 1
MK_NK_MN
,
// 1
...
@@ -18,7 +18,7 @@ enum GemmMatrixLayout
...
@@ -18,7 +18,7 @@ enum GemmMatrixLayout
KM_NK_NM
,
// 7
KM_NK_NM
,
// 7
};
};
enum
GemmDataType
enum
struct
GemmDataType
{
{
F32_F32_F32
,
// 0
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
F16_F16_F16
,
// 1
...
@@ -45,8 +45,8 @@ int profile_gemm(int argc, char* argv[])
...
@@ -45,8 +45,8 @@ int profile_gemm(int argc, char* argv[])
exit
(
1
);
exit
(
1
);
}
}
const
int
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
int
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
...
@@ -223,6 +223,26 @@ int profile_gemm(int argc, char* argv[])
...
@@ -223,6 +223,26 @@ int profile_gemm(int argc, char* argv[])
(
StrideC
<
0
)
?
N
:
StrideC
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
KBatch
);
}
}
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
}
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
{
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
...
@@ -243,6 +263,66 @@ int profile_gemm(int argc, char* argv[])
...
@@ -243,6 +263,66 @@ int profile_gemm(int argc, char* argv[])
(
StrideC
<
0
)
?
N
:
StrideC
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
KBatch
);
}
}
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
}
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
{
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
...
@@ -263,6 +343,46 @@ int profile_gemm(int argc, char* argv[])
...
@@ -263,6 +343,46 @@ int profile_gemm(int argc, char* argv[])
(
StrideC
<
0
)
?
N
:
StrideC
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
KBatch
);
}
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
}
else
else
{
{
throw
std
::
runtime_error
(
"wrong! this GEMM data_type & layout is not implemented"
);
throw
std
::
runtime_error
(
"wrong! this GEMM data_type & layout is not implemented"
);
...
...
profiler/src/profile_gemm_bias_2d.cpp
View file @
e72c0c43
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#include <half.hpp>
#include <half.hpp>
#include "profile_gemm_bias_2d_impl.hpp"
#include "profile_gemm_bias_2d_impl.hpp"
enum
GemmMatrixLayout
enum
struct
GemmMatrixLayout
{
{
MK_KN_MN
,
// 0
MK_KN_MN
,
// 0
MK_NK_MN
,
// 1
MK_NK_MN
,
// 1
...
@@ -18,7 +18,7 @@ enum GemmMatrixLayout
...
@@ -18,7 +18,7 @@ enum GemmMatrixLayout
KM_NK_NM
,
// 7
KM_NK_NM
,
// 7
};
};
enum
GemmDataType
enum
struct
GemmDataType
{
{
F32_F32_F32
,
// 0
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
F16_F16_F16
,
// 1
...
@@ -45,8 +45,8 @@ int profile_gemm_bias_2d(int argc, char* argv[])
...
@@ -45,8 +45,8 @@ int profile_gemm_bias_2d(int argc, char* argv[])
exit
(
1
);
exit
(
1
);
}
}
const
int
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
int
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
...
...
profiler/src/profile_gemm_bias_relu.cpp
View file @
e72c0c43
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#include <half.hpp>
#include <half.hpp>
#include "profile_gemm_bias_relu_impl.hpp"
#include "profile_gemm_bias_relu_impl.hpp"
enum
GemmMatrixLayout
enum
struct
GemmMatrixLayout
{
{
MK_KN_MN
,
// 0
MK_KN_MN
,
// 0
MK_NK_MN
,
// 1
MK_NK_MN
,
// 1
...
@@ -18,7 +18,7 @@ enum GemmMatrixLayout
...
@@ -18,7 +18,7 @@ enum GemmMatrixLayout
KM_NK_NM
,
// 7
KM_NK_NM
,
// 7
};
};
enum
GemmDataType
enum
struct
GemmDataType
{
{
F32_F32_F32
,
// 0
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
F16_F16_F16
,
// 1
...
@@ -43,8 +43,8 @@ int profile_gemm_bias_relu(int argc, char* argv[])
...
@@ -43,8 +43,8 @@ int profile_gemm_bias_relu(int argc, char* argv[])
exit
(
1
);
exit
(
1
);
}
}
const
int
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
int
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
...
...
Prev
1
…
7
8
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment