gaoqiong / composable_kernel

Commit c6891e12, authored Jul 01, 2022 by rocking
Merge branch 'develop' into standalone-layernorm
Parents: f591ad27, 8e374781
Changes: 296 files in the full merge; this page shows 20 changed files with 149 additions and 136 deletions (+149, -136).
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.cpp  +2 -2
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.cpp  +2 -2
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.cpp  +2 -2
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.cpp  +2 -2
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.cpp  +2 -2
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.cpp  +2 -2
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.cpp  +2 -2
profiler/CMakeLists.txt  +2 -0
profiler/include/profile_batched_gemm_impl.hpp  +30 -19
profiler/include/profile_batched_gemm_reduce_impl.hpp  +7 -8
profiler/include/profile_conv_bwd_weight_impl.hpp  +4 -4
profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp  +3 -3
profiler/include/profile_conv_fwd_bias_relu_impl.hpp  +3 -3
profiler/include/profile_convnd_bwd_data_impl.hpp  +15 -16
profiler/include/profile_gemm_add_add_fastgelu_impl.hpp  +21 -22
profiler/include/profile_gemm_bias_2d_impl.hpp  +11 -12
profiler/include/profile_gemm_bias_add_reduce_impl.hpp  +7 -8
profiler/include/profile_gemm_bias_relu_add_impl.hpp  +7 -8
profiler/include/profile_gemm_bias_relu_impl.hpp  +7 -8
profiler/include/profile_gemm_impl.hpp  +18 -11
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.cpp

@@ -6,7 +6,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_reduce_instance {
+namespace instance {
 // clang-format off
 // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -36,7 +36,7 @@ ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
 ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
 // clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
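All seven device_reduce_instance_threadwise_*.cpp files in this page carry the same two-line change: the per-operator namespace device_reduce_instance is folded into the shared instance namespace, while the ADD_THREADWISE_INST_BY_ID registrations themselves are untouched. A minimal sketch of the resulting nesting (illustration only; the macro body is not part of this diff):

    namespace ck {
    namespace tensor_operation {
    namespace device {
    namespace instance { // previously: namespace device_reduce_instance

    // registrations stay exactly as before, e.g.
    // ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);

    } // namespace instance
    } // namespace device
    } // namespace tensor_operation
    } // namespace ck

The remaining six reduce files below repeat the same change for their respective data-type combinations.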
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.cpp

@@ -6,7 +6,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_reduce_instance {
+namespace instance {
 // clang-format off
 // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -24,7 +24,7 @@ ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
 ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
 // clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.cpp

@@ -6,7 +6,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_reduce_instance {
+namespace instance {
 // clang-format off
 // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -48,7 +48,7 @@ ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1);
 ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1);
 // clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.cpp

@@ -6,7 +6,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_reduce_instance {
+namespace instance {
 // clang-format off
 // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -24,7 +24,7 @@ ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1);
 ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1);
 // clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.cpp

@@ -6,7 +6,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_reduce_instance {
+namespace instance {
 // clang-format off
 // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -48,7 +48,7 @@ ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1);
 ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1);
 // clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.cpp

@@ -6,7 +6,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_reduce_instance {
+namespace instance {
 // clang-format off
 // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -21,7 +21,7 @@ ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1);
 // clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.cpp

@@ -6,7 +6,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_reduce_instance {
+namespace instance {
 // clang-format off
 // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -36,7 +36,7 @@ ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1);
 ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1);
 // clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
profiler/CMakeLists.txt

@@ -22,6 +22,7 @@ set(PROFILER_SOURCE
     src/profile_conv_bwd_weight.cpp
     src/profile_batched_gemm_reduce.cpp
     src/profile_gemm_add_add_fastgelu.cpp
+    src/profile_normalization.cpp
 )
 add_executable(ckProfiler ${PROFILER_SOURCE})
@@ -46,4 +47,5 @@ target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
 target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
 target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance)
 target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance)
+target_link_libraries(ckProfiler PRIVATE device_normalization_instance)
 target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
profiler/include/profile_batched_gemm_impl.hpp

@@ -10,7 +10,7 @@
 #include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-#include "ck/library/tensor_operation_instance/gpu/device_batched_gemm_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/host_tensor/device_memory.hpp"
@@ -34,6 +34,9 @@ bool profile_batched_gemm_impl(int do_verification,
                                int M,
                                int N,
                                int K,
+                               int BatchStrideA,
+                               int BatchStrideB,
+                               int BatchStrideC,
                                int StrideA,
                                int StrideB,
                                int StrideC,
@@ -45,25 +48,28 @@ bool profile_batched_gemm_impl(int do_verification,
                                        std::size_t row,
                                        std::size_t col,
                                        std::size_t stride,
+                                       std::size_t batch_stride,
                                        auto layout) {
        if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                       std::vector<std::size_t>({row * stride, stride, 1}));
+                                       std::vector<std::size_t>({batch_stride, stride, 1}));
        }
        else
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                       std::vector<std::size_t>({col * stride, 1, stride}));
+                                       std::vector<std::size_t>({batch_stride, 1, stride}));
        }
    };

-    Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{}));
-    Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{}));
+    Tensor<ADataType> a_g_m_k(
+        f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
+    Tensor<BDataType> b_g_k_n(
+        f_host_tensor_descriptor(BatchCount, K, N, StrideB, BatchStrideB, BLayout{}));
     Tensor<CDataType> c_g_m_n_host_result(
-        f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
+        f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
     Tensor<CDataType> c_g_m_n_device_result(
-        f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
+        f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));

     std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
     std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
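The practical effect of threading batch_stride through f_host_tensor_descriptor is that the distance between consecutive batches is no longer forced to equal row * stride (or col * stride); the caller can now describe batches that are padded or otherwise spaced apart. A small self-contained sketch of the same idea, assuming only that the descriptor is a (lengths, strides) pair as above (the struct and helper names below are hypothetical, not CK API):

    #include <cstddef>
    #include <vector>

    // Hypothetical stand-in for the descriptor built above: lengths and strides
    // of a [batch, row, col] tensor, row-major within each batch.
    struct Desc
    {
        std::vector<std::size_t> lengths;
        std::vector<std::size_t> strides;
    };

    // Mirrors the new lambda: the batch stride is an explicit argument instead of
    // being derived as row * stride, so padded batches can be described.
    Desc make_batched_desc(std::size_t batch_count,
                           std::size_t row,
                           std::size_t col,
                           std::size_t stride,
                           std::size_t batch_stride)
    {
        return Desc{{batch_count, row, col}, {batch_stride, stride, 1}};
    }

    // Example: 4 batches of a 128x64 row-major matrix, each batch padded to
    // 9000 elements (larger than the tight 128 * 64 = 8192).
    // auto d = make_batched_desc(4, 128, 64, 64, 9000);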
@@ -116,19 +122,21 @@ bool profile_batched_gemm_impl(int do_verification,
     b_device_buf.ToDevice(b_g_k_n.mData.data());
     c_device_buf.ToDevice(c_g_m_n_device_result.mData.data());

-    // add device op instances
-    const auto op_ptrs = ck::tensor_operation::device::device_batched_gemm_instance::
-        get_device_batched_gemm_instances<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>();
-
-    if(op_ptrs.size() <= 0)
-    {
-        throw std::runtime_error("wrong! no device GEMM instance found");
-    }
+    using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemm<ALayout,
+                                                                     BLayout,
+                                                                     CLayout,
+                                                                     ADataType,
+                                                                     BDataType,
+                                                                     CDataType,
+                                                                     AElementOp,
+                                                                     BElementOp,
+                                                                     CElementOp>;
+
+    // get device op instances
+    const auto op_ptrs = ck::tensor_operation::device::instance::
+        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
+
+    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;

     std::string best_op_name;
     float best_ave_time = 0;
@@ -148,6 +156,9 @@ bool profile_batched_gemm_impl(int do_verification,
                 StrideA,
                 StrideB,
                 StrideC,
+                BatchStrideA,
+                BatchStrideB,
+                BatchStrideC,
                 ck::tensor_operation::element_wise::PassThrough{},
                 ck::tensor_operation::element_wise::PassThrough{},
                 ck::tensor_operation::element_wise::PassThrough{},
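The block above is the pattern the batched-GEMM and plain-GEMM profilers move to in this merge: instead of calling a per-operator get_device_*_instances() helper, the profiler names the abstract device operation it wants and asks the common factory for every registered instance. A minimal sketch of that lookup in isolation, assuming the headers this file includes behave as in this commit; the half-precision, all-row-major choice is only an example:

    #include <cstddef>
    #include <iostream>

    #include "ck/ck.hpp"
    #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
    #include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
    #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
    #include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"

    std::size_t count_batched_gemm_instances()
    {
        using Row         = ck::tensor_layout::gemm::RowMajor;
        using PassThrough = ck::tensor_operation::element_wise::PassThrough;

        // The operation type is the lookup key: layouts, data types, element-wise ops,
        // in the same order used by the diff above.
        using DeviceOp = ck::tensor_operation::device::
            DeviceBatchedGemm<Row, Row, Row,
                              ck::half_t, ck::half_t, ck::half_t,
                              PassThrough, PassThrough, PassThrough>;

        // The factory returns every instance registered for exactly this key.
        const auto op_ptrs = ck::tensor_operation::device::instance::
            DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

        std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
        return op_ptrs.size();
    }

Each returned op_ptr is then configured with the concrete problem (M, N, K, the strides and, new in this file, the per-tensor batch strides) inside the timing loop, as the last hunk above shows.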
profiler/include/profile_batched_gemm_reduce_impl.hpp

@@ -19,7 +19,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_gemm_instance {
+namespace instance {

 using F32 = float;
 using F16 = ck::half_t;
@@ -44,7 +44,7 @@ void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn
 void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
     std::vector<DeviceGemmReduceNoOpPtr>&);

-} // namespace device_gemm_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -208,8 +208,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
     b_device_buf.ToDevice(b_g_k_n.mData.data());

     // add device GEMM instances
-    std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmReduceNoOpPtr>
-        gemm_ptrs;
+    std::vector<ck::tensor_operation::device::instance::DeviceGemmReduceNoOpPtr> gemm_ptrs;

     if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
                  is_same<CDataType, half_t>::value)
@@ -218,7 +217,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
            is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
            is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
                     gemm_ptrs);
         }
@@ -226,7 +225,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
            is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
            is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
                     gemm_ptrs);
         }
@@ -234,7 +233,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
            is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
            is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
                     gemm_ptrs);
         }
@@ -242,7 +241,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
            is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
            is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
                     gemm_ptrs);
         }
profiler/include/profile_conv_bwd_weight_impl.hpp

@@ -18,7 +18,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_conv2d_bwd_weight_instance {
+namespace instance {

 using DeviceConvBwdWeightNoOpPtr =
     DeviceConvBwdWeightPtr<ck::tensor_operation::element_wise::PassThrough,
@@ -31,7 +31,7 @@ void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(
 void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(
     std::vector<DeviceConvBwdWeightNoOpPtr>&);

-} // namespace device_conv2d_bwd_weight_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -165,14 +165,14 @@ bool profile_conv_bwd_weight_impl(int do_verification,
                  ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
                  ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
    {
-        ck::tensor_operation::device::device_conv2d_bwd_weight_instance::
+        ck::tensor_operation::device::instance::
             add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
    }
    else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::half_t> &&
                      ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> &&
                      ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
    {
-        ck::tensor_operation::device::device_conv2d_bwd_weight_instance::
+        ck::tensor_operation::device::instance::
             add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
    }
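Seen from a call site, the change in this header (and in the conv and GEMM profiler headers that follow) is purely a rename of the qualifier; the instance-adding functions and the pointer typedef keep their names. A short sketch using the f32 adder declared above (the wrapper function name is hypothetical, and this assumes the header's declarations are visible; it is not a standalone program):

    #include <vector>

    // Hypothetical helper showing the new qualifier at a call site.
    void populate_f32_conv_bwd_weight_instances(
        std::vector<ck::tensor_operation::device::instance::DeviceConvBwdWeightNoOpPtr>& conv_ptrs)
    {
        // Before this merge the qualifier was
        // ck::tensor_operation::device::device_conv2d_bwd_weight_instance::...
        ck::tensor_operation::device::instance::
            add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
    }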
profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp

@@ -17,7 +17,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_conv2d_fwd_bias_activation_add_instance {
+namespace instance {

 using DeviceConvFwdBiasReluAddPtr =
     DeviceConvFwdBiasActivationAddPtr<ck::tensor_operation::element_wise::PassThrough,
@@ -27,7 +27,7 @@ using DeviceConvFwdBiasReluAddPtr =
 void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances(
     std::vector<DeviceConvFwdBiasReluAddPtr>&);

-} // namespace device_conv2d_fwd_bias_activation_add_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -179,7 +179,7 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification,
                  ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> &&
                  ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
    {
-        ck::tensor_operation::device::device_conv2d_fwd_bias_activation_add_instance::
+        ck::tensor_operation::device::instance::
             add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
    }
profiler/include/profile_conv_fwd_bias_relu_impl.hpp

@@ -17,7 +17,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_conv2d_fwd_bias_activation_instance {
+namespace instance {

 using DeviceConvFwdBiasReluPtr =
     DeviceConvFwdBiasActivationPtr<ck::tensor_operation::element_wise::PassThrough,
@@ -27,7 +27,7 @@ using DeviceConvFwdBiasReluPtr =
 void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances(
     std::vector<DeviceConvFwdBiasReluPtr>&);

-} // namespace device_conv2d_fwd_bias_activation_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -169,7 +169,7 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
                  ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> &&
                  ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
    {
-        ck::tensor_operation::device::device_conv2d_fwd_bias_activation_instance::
+        ck::tensor_operation::device::instance::
             add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
    }
profiler/include/profile_convnd_bwd_data_impl.hpp

@@ -22,7 +22,7 @@ using INT8 = int8_t;
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_conv2d_bwd_data_instance {
+namespace instance {

 using DeviceConvBwdDataNoOpPtr =
     DeviceConvBwdDataPtr<ck::tensor_operation::element_wise::PassThrough,
@@ -54,15 +54,14 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
     std::vector<DeviceConvBwdDataNoOpPtr>&);
 void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
     std::vector<DeviceConvBwdDataNoOpPtr>&);

-} // namespace device_conv2d_bwd_data_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck

 namespace ck {
 namespace profiler {

-using DeviceConvBwdDataNoOpPtr =
-    ck::tensor_operation::device::device_conv2d_bwd_data_instance::DeviceConvBwdDataNoOpPtr;
+using DeviceConvBwdDataNoOpPtr = ck::tensor_operation::device::instance::DeviceConvBwdDataNoOpPtr;

 template <typename InLayout>
 HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector<std::size_t>& dims,
@@ -144,15 +143,15 @@ void get_device_conv_bwd_data_op_ptr(
         switch(num_dim_spatial)
         {
         case 1:
-            ck::tensor_operation::device::device_conv2d_bwd_data_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs);
             break;
         case 2:
-            ck::tensor_operation::device::device_conv2d_bwd_data_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
             break;
         case 3:
-            ck::tensor_operation::device::device_conv2d_bwd_data_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs);
             break;
         default: break;
@@ -165,15 +164,15 @@ void get_device_conv_bwd_data_op_ptr(
         switch(num_dim_spatial)
         {
         case 1:
-            ck::tensor_operation::device::device_conv2d_bwd_data_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs);
             break;
         case 2:
-            ck::tensor_operation::device::device_conv2d_bwd_data_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
             break;
         case 3:
-            ck::tensor_operation::device::device_conv2d_bwd_data_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs);
             break;
         default: break;
@@ -186,15 +185,15 @@ void get_device_conv_bwd_data_op_ptr(
         switch(num_dim_spatial)
         {
         case 1:
-            ck::tensor_operation::device::device_conv2d_bwd_data_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs);
             break;
         case 2:
-            ck::tensor_operation::device::device_conv2d_bwd_data_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
             break;
         case 3:
-            ck::tensor_operation::device::device_conv2d_bwd_data_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs);
             break;
         default: break;
@@ -207,15 +206,15 @@ void get_device_conv_bwd_data_op_ptr(
         switch(num_dim_spatial)
         {
         case 1:
-            ck::tensor_operation::device::device_conv2d_bwd_data_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs);
             break;
         case 2:
-            ck::tensor_operation::device::device_conv2d_bwd_data_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
             break;
         case 3:
-            ck::tensor_operation::device::device_conv2d_bwd_data_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs);
             break;
         default: break;
profiler/include/profile_gemm_add_add_fastgelu_impl.hpp

@@ -10,13 +10,12 @@
 #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-#include "ck/library/tensor_operation_instance/gpu/device_gemm_add_add_fastgelu_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/host_tensor/device_memory.hpp"
 #include "ck/library/host_tensor/host_tensor.hpp"
 #include "ck/library/host_tensor/host_tensor_generator.hpp"
 #include "ck/library/host_tensor/host_conv.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"

 namespace ck {
@@ -30,9 +29,7 @@ template <typename ADataType,
           typename EDataType,
           typename ALayout,
           typename BLayout,
-          typename D0Layout,
-          typename D1Layout,
-          typename ELayout>
+          typename DELayout> // assume Ds and E have same layout
 bool profile_gemm_add_add_fastgelu_impl(int do_verification,
                                         int init_method,
                                         bool /*do_log*/,
@@ -62,10 +59,10 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
     Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
     Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
-    Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{}));
-    Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{}));
-    Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
-    Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
+    Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, DELayout{}));
+    Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, DELayout{}));
+    Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{}));
+    Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{}));

     std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
     std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
@@ -100,19 +97,21 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
     const auto b_element_op   = BElementOp{};
     const auto cde_element_op = CDEElementOp{};

-    // add device op instances
-    const auto op_ptrs = ck::tensor_operation::device::device_gemm_instance::
-        get_device_gemm_add_add_fastgelu_instances<ADataType,
-                                                   BDataType,
-                                                   AccDataType,
-                                                   D0DataType,
-                                                   D1DataType,
-                                                   EDataType,
-                                                   ALayout,
-                                                   BLayout,
-                                                   D0Layout,
-                                                   D1Layout,
-                                                   ELayout>();
+    using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD<
+        ALayout,
+        BLayout,
+        DELayout,
+        ADataType,
+        BDataType,
+        ck::Tuple<D0DataType, D1DataType>,
+        EDataType,
+        ck::tensor_operation::element_wise::PassThrough,
+        ck::tensor_operation::element_wise::PassThrough,
+        ck::tensor_operation::element_wise::AddAddFastGelu>;
+
+    // get device op instances
+    const auto op_ptrs = ck::tensor_operation::device::instance::
+        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

     std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
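The interesting part of this file is the new factory key: the two addend tensors and the output no longer carry independent layout parameters (hence the single DELayout), the D tensors are grouped into a ck::Tuple, and the fused epilogue is expressed through the AddAddFastGelu element-wise op. A minimal sketch of building that key and querying the factory, assuming the headers included by this file; the concrete f16/row-major choices are examples only:

    #include <iostream>

    #include "ck/ck.hpp"
    #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
    #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
    #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
    #include "ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp"

    int main()
    {
        using Row            = ck::tensor_layout::gemm::RowMajor;
        using F16            = ck::half_t;
        using PassThrough    = ck::tensor_operation::element_wise::PassThrough;
        using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;

        // A (MxK) * B (KxN) + D0 + D1 -> FastGelu -> E, with D0, D1 and E sharing
        // one layout, mirroring the template-parameter order in the diff above.
        using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD<Row,
                                                                           Row,
                                                                           Row,
                                                                           F16,
                                                                           F16,
                                                                           ck::Tuple<F16, F16>,
                                                                           F16,
                                                                           PassThrough,
                                                                           PassThrough,
                                                                           AddAddFastGelu>;

        const auto op_ptrs = ck::tensor_operation::device::instance::
            DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

        std::cout << "found " << op_ptrs.size() << " gemm+add+add+fastgelu instances\n";
        return 0;
    }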
profiler/include/profile_gemm_bias_2d_impl.hpp

@@ -17,7 +17,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_gemm_instance {
+namespace instance {

 using DeviceGemmAlphaBetaPtr = ck::tensor_operation::device::DeviceGemmBiasPtr<
     ck::tensor_operation::element_wise::PassThrough,
@@ -48,7 +48,7 @@ void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances(
 void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances(
     std::vector<DeviceGemmAlphaBetaPtr>&);

-} // namespace device_gemm_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -159,8 +159,7 @@ void profile_gemm_bias_2d_impl(int do_verification,
     c_device_buf.ToDevice(c_m_n_device_result.mData.data());

     // add device GEMM instances
-    std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmAlphaBetaPtr>
-        gemm_ptrs;
+    std::vector<ck::tensor_operation::device::instance::DeviceGemmAlphaBetaPtr> gemm_ptrs;

     if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
                  is_same<CDataType, half_t>::value)
@@ -169,28 +168,28 @@ void profile_gemm_bias_2d_impl(int do_verification,
            is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
            is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
         }
         else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                           is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                           is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
         }
         else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                           is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                           is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
         }
         else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                           is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                           is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
         }
     }
@@ -201,28 +200,28 @@ void profile_gemm_bias_2d_impl(int do_verification,
            is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
            is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
         }
         else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                           is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                           is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
         }
         else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                           is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                           is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
         }
         else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                           is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                           is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
         }
     }
profiler/include/profile_gemm_bias_add_reduce_impl.hpp

@@ -19,7 +19,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_gemm_instance {
+namespace instance {

 using F32 = float;
 using F16 = ck::half_t;
@@ -45,7 +45,7 @@ void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f
 void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
     std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);

-} // namespace device_gemm_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -236,8 +236,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
     d0_device_buf.ToDevice(d0_m_n.mData.data());

     // add device GEMM instances
-    std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmBiasAddReduceNoOpPtr>
-        gemm_ptrs;
+    std::vector<ck::tensor_operation::device::instance::DeviceGemmBiasAddReduceNoOpPtr> gemm_ptrs;

     if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
                  is_same<CDataType, half_t>::value)
@@ -246,7 +245,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
            is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
            is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
                     gemm_ptrs);
         }
@@ -254,7 +253,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
            is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
            is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
                     gemm_ptrs);
         }
@@ -262,7 +261,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
            is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
            is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
                     gemm_ptrs);
         }
@@ -270,7 +269,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
            is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
            is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
                     gemm_ptrs);
         }
profiler/include/profile_gemm_bias_relu_add_impl.hpp

@@ -18,7 +18,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_gemm_instance {
+namespace instance {

 using DeviceGemmBiasReluAddPtr = ck::tensor_operation::device::DeviceGemmBiasActivationAddPtr<
     ck::tensor_operation::element_wise::PassThrough,
@@ -34,7 +34,7 @@ void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances(
 void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances(
     std::vector<DeviceGemmBiasReluAddPtr>&);

-} // namespace device_gemm_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -158,8 +158,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
     c1_m_n_device_buf.ToDevice(c1_m_n.mData.data());

     // add device GEMM instances
-    std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmBiasReluAddPtr>
-        gemm_ptrs;
+    std::vector<ck::tensor_operation::device::instance::DeviceGemmBiasReluAddPtr> gemm_ptrs;

     if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
                  is_same<CDataType, half_t>::value)
@@ -168,7 +167,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
            is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
            is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
         }
@@ -176,7 +175,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
            is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
            is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
         }
@@ -184,7 +183,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
            is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
            is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
         }
@@ -192,7 +191,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification,
            is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
            is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
         }
profiler/include/profile_gemm_bias_relu_impl.hpp

@@ -18,7 +18,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_gemm_instance {
+namespace instance {

 using DeviceGemmBiasReluPtr = ck::tensor_operation::device::DeviceGemmBiasActivationPtr<
     ck::tensor_operation::element_wise::PassThrough,
@@ -34,7 +34,7 @@ void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances(
 void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances(
     std::vector<DeviceGemmBiasReluPtr>&);

-} // namespace device_gemm_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -144,8 +144,7 @@ void profile_gemm_bias_relu_impl(int do_verification,
     c0_n_device_buf.ToDevice(c0_n.mData.data());

     // add device GEMM instances
-    std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmBiasReluPtr>
-        gemm_ptrs;
+    std::vector<ck::tensor_operation::device::instance::DeviceGemmBiasReluPtr> gemm_ptrs;

     if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
                  is_same<CDataType, half_t>::value)
@@ -154,28 +153,28 @@ void profile_gemm_bias_relu_impl(int do_verification,
            is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
            is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
         }
         else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                           is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                           is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
         }
         else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                           is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                           is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
         }
         else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                           is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                           is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
         }
     }
profiler/include/profile_gemm_impl.hpp

@@ -12,7 +12,7 @@
 #include "ck/tensor_operation/gpu/device/device_gemm.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-#include "ck/library/tensor_operation_instance/gpu/device_gemm_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/host_tensor/device_memory.hpp"
@@ -94,14 +94,21 @@ int profile_gemm_impl(int do_verification,
     b_device_buf.ToDevice(b_k_n.mData.data());
     c_device_buf.ToDevice(c_m_n_device_result.mData.data());

-    // add device op instances
-    const auto op_ptrs = ck::tensor_operation::device::device_gemm_instance::
-        get_device_gemm_instances<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>();
-
-    if(op_ptrs.size() <= 0)
-    {
-        throw std::runtime_error("wrong! no device GEMM instance found");
-    }
+    using DeviceOp = ck::tensor_operation::device::DeviceGemm<ALayout,
+                                                              BLayout,
+                                                              CLayout,
+                                                              ADataType,
+                                                              BDataType,
+                                                              CDataType,
+                                                              AElementOp,
+                                                              BElementOp,
+                                                              CElementOp>;
+
+    // get device op instances
+    const auto op_ptrs = ck::tensor_operation::device::instance::
+        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
+
+    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;

     // Run reference GEMM
     if(do_verification)
@@ -141,9 +148,9 @@ int profile_gemm_impl(int do_verification,
             StrideA,
             StrideB,
             StrideC,
-            ck::tensor_operation::element_wise::PassThrough{},
-            ck::tensor_operation::element_wise::PassThrough{},
-            ck::tensor_operation::element_wise::PassThrough{});
+            a_element_op,
+            b_element_op,
+            c_element_op);

         auto invoker_ptr = op_ptr->MakeInvokerPointer();