Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
cd929111
Commit
cd929111
authored
Dec 05, 2021
by
Chao Liu
Browse files
refactor
parent
c345719a
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
164 additions
and
82 deletions
+164
-82
device_operation/device_conv2d_fwd_xdl_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp
...v2d_fwd_xdl_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp
+3
-3
device_operation/include/device_conv2d_fwd_xdl_bias_activation_add_nhwc_kyxc_nhwk.hpp
...ice_conv2d_fwd_xdl_bias_activation_add_nhwc_kyxc_nhwk.hpp
+50
-1
device_operation/include/device_conv_fwd_bias_activation_add.hpp
...operation/include/device_conv_fwd_bias_activation_add.hpp
+8
-4
device_operation/include/element_wise_operation.hpp
device_operation/include/element_wise_operation.hpp
+30
-0
example/4_conv_xdl_bias_relu_add/conv_xdl_bias_relu_add.cpp
example/4_conv_xdl_bias_relu_add/conv_xdl_bias_relu_add.cpp
+2
-4
profiler/CMakeLists.txt
profiler/CMakeLists.txt
+20
-21
profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp
profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp
+36
-32
profiler/include/profile_conv_fwd_impl.hpp
profiler/include/profile_conv_fwd_impl.hpp
+6
-6
profiler/profile_conv_fwd_bias_relu_add.cpp
profiler/profile_conv_fwd_bias_relu_add.cpp
+7
-7
profiler/profiler.cpp
profiler/profiler.cpp
+2
-4
No files found.
device_operation/device_conv2d_fwd_xdl_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp
View file @
cd929111
...
...
@@ -39,8 +39,8 @@ using device_conv2d_fwd_xdl_bias_relu_add_nhwc_kyxc_nhwk_f16_instances = std::tu
// clang-format on
>
;
add_device_conv2d_fwd_bias_relu_add_xdl_nhwc_kyxc_nhwk_fp16_instances
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddRe
L
uAdd
>>&
void
add_device_conv2d_fwd_bias_relu_add_xdl_nhwc_kyxc_nhwk_fp16_instances
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddRe
l
uAdd
>>&
instance_container
)
{
using
Instances
=
device_conv2d_fwd_xdl_bias_relu_add_nhwc_kyxc_nhwk_f16_instances
;
...
...
@@ -52,7 +52,7 @@ add_device_conv2d_fwd_bias_relu_add_xdl_nhwc_kyxc_nhwk_fp16_instances(
auto
instance
=
Instance
{};
device_conv_instances
.
push_back
(
std
::
make_unique
<
Instance
>
(
instance
));
instance_container
.
push_back
(
std
::
make_unique
<
Instance
>
(
instance
));
});
}
...
...
device_operation/include/device_conv2d_fwd_xdl_bias_activation_add_nhwc_kyxc_nhwk.hpp
View file @
cd929111
...
...
@@ -53,7 +53,9 @@ template <typename InDataType,
bool
ABlockLdsAddExtraM
,
bool
BBlockLdsAddExtraN
>
struct
DeviceConv2dFwdXdl_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
:
public
BaseOperator
:
public
DeviceConvFwdBiasActivationAdd
<
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
{
using
DeviceOp
=
DeviceConv2dFwdXdl_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
;
...
...
@@ -618,6 +620,53 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Out
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in_grid
,
const
void
*
p_wei_grid
,
void
*
p_out_grid
,
const
void
*
p_bias_grid
,
const
void
*
p_resi_grid
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_grid
),
static_cast
<
const
WeiDataType
*>
(
p_wei_grid
),
static_cast
<
OutDataType
*>
(
p_out_grid
),
static_cast
<
const
OutDataType
*>
(
p_bias_grid
),
static_cast
<
const
OutDataType
*>
(
p_resi_grid
),
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
1
,
1
,
in_element_op
,
wei_element_op
,
out_element_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
...
...
device_operation/include/device_conv_fwd_bias_activation_add.hpp
View file @
cd929111
#ifndef DEVICE_CONV_FWD_HPP
#define DEVICE_CONV_FWD_HPP
#ifndef DEVICE_CONV_FWD_
BIAS_ACTIVATION_ADD_
HPP
#define DEVICE_CONV_FWD_
BIAS_ACTIVATION_ADD_
HPP
#include <iostream>
#include "device_base.hpp"
...
...
@@ -17,6 +17,8 @@ struct DeviceConvFwdBiasActivationAdd : public BaseOperator
MakeArgumentPointer
(
const
void
*
p_in
,
const
void
*
p_wei
,
void
*
p_out
,
const
void
*
p_bias
,
const
void
*
p_resi
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
...
...
@@ -37,8 +39,10 @@ struct DeviceConvFwdBiasActivationAdd : public BaseOperator
template
<
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
>
using
DeviceConvFwdPtr
=
std
::
unique_ptr
<
DeviceConvFwd
<
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>>
;
using
DeviceConvFwdBiasActivationAddPtr
=
std
::
unique_ptr
<
DeviceConvFwdBiasActivationAdd
<
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
...
...
device_operation/include/element_wise_operation.hpp
View file @
cd929111
...
...
@@ -14,6 +14,36 @@ struct PassThrough
}
};
struct
AddReluAdd
{
template
<
typename
T1
,
typename
T2
>
__host__
constexpr
float
operator
()(
float
v0
,
T1
v1
,
T2
v2
)
const
{
float
b
=
v0
+
v1
;
float
c
=
b
>
0
?
b
:
0
;
float
d
=
c
+
v2
;
return
d
;
}
template
<
typename
T1
,
typename
T2
>
__device__
constexpr
float
operator
()(
float
v0
,
T1
v1
,
T2
v2
)
const
{
#if 0
float a = v1 + v0;
float b = max(a, float(0));
float c = b + v2;
return c;
#else
float
b
=
v1
+
v2
;
float
c
=
(
v0
>
-
v1
)
?
b
+
v0
:
v2
;
return
c
;
#endif
}
};
}
// namespace element_wise
}
// namespace tensor_operation
}
// namespace ck
...
...
example/4_conv_xdl_bias_relu_add/conv_xdl_bias_relu_add.cpp
View file @
cd929111
...
...
@@ -118,10 +118,8 @@ struct BiasReluAdd
return c;
#else
float
a
=
v1
+
v2
;
float
b
=
v2
;
float
c
=
(
v0
>
-
v1
)
?
a
+
v0
:
v2
;
float
b
=
v1
+
v2
;
float
c
=
(
v0
>
-
v1
)
?
b
+
v0
:
v2
;
return
c
;
#endif
...
...
profiler/CMakeLists.txt
View file @
cd929111
...
...
@@ -30,37 +30,36 @@ target_compile_features(device_gemm_instance PUBLIC)
set_target_properties
(
device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
install
(
TARGETS device_gemm_instance LIBRARY DESTINATION lib
)
# device_conv_instance
set
(
DEVICE_CONV_INSTANCE_SOURCE
# device_conv
2d_fwd
_instance
set
(
DEVICE_CONV
2D_FWD
_INSTANCE_SOURCE
${
PROJECT_SOURCE_DIR
}
/device_operation/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instance.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instance.cpp;
)
add_library
(
device_conv_instance SHARED
${
DEVICE_CONV_INSTANCE_SOURCE
}
)
target_include_directories
(
device_conv_instance SYSTEM PUBLIC $<BUILD_INTERFACE:
${
HALF_INCLUDE_DIR
}
>
)
target_compile_features
(
device_conv_instance PUBLIC
)
set_target_properties
(
device_conv_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
install
(
TARGETS device_conv_instance LIBRARY DESTINATION lib
)
add_library
(
device_conv
2d_fwd
_instance SHARED
${
DEVICE_CONV
2D_FWD
_INSTANCE_SOURCE
}
)
target_include_directories
(
device_conv
2d_fwd
_instance SYSTEM PUBLIC $<BUILD_INTERFACE:
${
HALF_INCLUDE_DIR
}
>
)
target_compile_features
(
device_conv
2d_fwd
_instance PUBLIC
)
set_target_properties
(
device_conv
2d_fwd
_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
install
(
TARGETS device_conv
2d_fwd
_instance LIBRARY DESTINATION lib
)
#
# device_conv_bias_relu_add_instance
#
set(DEVICE_CONV_BIAS_RELU_ADD_INSTANCE_SOURCE
#
${PROJECT_SOURCE_DIR}/device_operation/device_conv2d_fwd_xdl_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp;
#
)
#
#
add_library(device_conv_bias_relu_add_instance SHARED ${DEVICE_CONV_BIAS_RELU_ADD_INSTANCE_SOURCE})
#
target_include_directories(device_conv_bias_relu_add_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
#
target_compile_features(device_conv_bias_relu_add_instance PUBLIC)
#
set_target_properties(device_conv_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
#
install(TARGETS device_conv_bias_relu_add_instance LIBRARY DESTINATION lib)
# device_conv
2d_fwd
_bias_relu_add_instance
set
(
DEVICE_CONV
2D_FWD
_BIAS_RELU_ADD_INSTANCE_SOURCE
${
PROJECT_SOURCE_DIR
}
/device_operation/device_conv2d_fwd_xdl_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp;
)
add_library
(
device_conv
2d_fwd
_bias_relu_add_instance SHARED
${
DEVICE_CONV
2D_FWD
_BIAS_RELU_ADD_INSTANCE_SOURCE
}
)
target_include_directories
(
device_conv
2d_fwd
_bias_relu_add_instance SYSTEM PUBLIC $<BUILD_INTERFACE:
${
HALF_INCLUDE_DIR
}
>
)
target_compile_features
(
device_conv
2d_fwd
_bias_relu_add_instance PUBLIC
)
set_target_properties
(
device_conv
2d_fwd
_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
install
(
TARGETS device_conv
2d_fwd
_bias_relu_add_instance LIBRARY DESTINATION lib
)
# ck_profiler
#set(PROFILER_SOURCE profiler.cpp profile_gemm.cpp profile_conv.cpp profile_conv_bias_relu_add.cpp)
set
(
PROFILER_SOURCE profiler.cpp profile_gemm.cpp profile_conv_fwd.cpp
)
set
(
PROFILER_SOURCE profiler.cpp profile_gemm.cpp profile_conv_fwd.cpp profile_conv_fwd_bias_relu_add.cpp
)
add_executable
(
ckProfiler
${
PROFILER_SOURCE
}
)
target_link_libraries
(
ckProfiler PRIVATE host_tensor
)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_conv_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_conv_bias_relu_add_instance)
target_link_libraries
(
ckProfiler PRIVATE device_conv
2d_fwd
_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_conv
2d_fwd
_bias_relu_add_instance
)
profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp
View file @
cd929111
...
...
@@ -6,23 +6,21 @@
#include "host_conv.hpp"
#include "tensor_layout.hpp"
#include "device_tensor.hpp"
#include "device_conv.hpp"
#include "device_conv
_fwd
.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_conv2d_fwd_
bias_activation_add_
instance
{
namespace
device_conv2d_fwd_instance
{
using
DeviceConvFwdNoOpPtr
=
DeviceConvFwdPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
add_device_conv2d_fwd_bias_relu_add_xdl_nhwc_kyxc_nhwk_fp16_instances
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReLuAdd
>>&
instance_container
)
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_fp16_instances
(
std
::
vector
<
DeviceConvFwdNoOpPtr
>&
);
}
// namespace device_conv2d_fwd_
bias_activation_add_
instance
}
// namespace device_conv2d_fwd_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
...
...
@@ -37,7 +35,7 @@ template <int NDimSpatial,
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
void
profile_conv
(
int
do_verification
,
void
profile_conv
_fwd_bias_relu_bias_impl
(
int
do_verification
,
int
init_method
,
bool
do_log
,
int
nrepeat
,
...
...
@@ -129,20 +127,26 @@ void profile_conv(int do_verification,
// add device Conv instances
std
::
vector
<
DeviceConvFwdNoOpPtr
>
conv_ptrs
;
ck
::
tensor_operation
::
device
::
device_conv_instance
::
add_device_conv_fwd_instance
<
2
,
InDataType
,
WeiDataType
,
OutDataType
,
InLayout
,
WeiLayout
,
OutLayout
>
(
conv_ptrs
);
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cv_t
<
InDataType
>
,
float
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
float
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
float
>
)
{
ck
::
tensor_operation
::
device
::
device_conv2d_fwd_instance
::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_fp32_instances
(
conv_ptrs
);
}
else
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cv_t
<
InDataType
>
,
ck
::
half_t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
ck
::
half_t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
ck
::
half_t
>
)
{
ck
::
tensor_operation
::
device
::
device_conv2d_fwd_instance
::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_fp16_instances
(
conv_ptrs
);
ck
::
tensor_operation
::
device
::
device_conv_instance
::
ck
::
tensor_operation
::
device
::
device_conv
2d_fwd
_instance
::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_fp16_instances
(
conv_ptrs
);
ck
::
tensor_operation
::
device
::
device_conv_instance
::
ck
::
tensor_operation
::
device
::
device_conv
2d_fwd
_instance
::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_fp16_instances
(
conv_ptrs
);
}
if
(
conv_ptrs
.
size
()
<=
0
)
{
...
...
profiler/include/profile_conv_fwd_impl.hpp
View file @
cd929111
...
...
@@ -135,16 +135,16 @@ void profile_conv_fwd_impl(int do_verification,
// add device Conv instances
std
::
vector
<
DeviceConvFwdNoOpPtr
>
conv_ptrs
;
if
constexpr
(
is_same_v
<
remove_cv_t
<
InDataType
>
,
float
>
&&
is_same_v
<
remove_cv_t
<
WeiDataType
>
,
float
>
&&
is_same_v
<
remove_cv_t
<
OutDataType
>
,
float
>
)
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cv_t
<
InDataType
>
,
float
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
float
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
float
>
)
{
ck
::
tensor_operation
::
device
::
device_conv2d_fwd_instance
::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_fp32_instances
(
conv_ptrs
);
}
else
if
constexpr
(
ck
::
is_same_v
<
remove_cv_t
<
InDataType
>
,
ck
::
half_t
>
&&
ck
::
is_same_v
<
remove_cv_t
<
WeiDataType
>
,
ck
::
half_t
>
&&
ck
::
is_same_v
<
remove_cv_t
<
OutDataType
>
,
ck
::
half_t
>
)
else
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cv_t
<
InDataType
>
,
ck
::
half_t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
ck
::
half_t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
ck
::
half_t
>
)
{
ck
::
tensor_operation
::
device
::
device_conv2d_fwd_instance
::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_fp16_instances
(
conv_ptrs
);
...
...
profiler/profile_conv_fwd_bias_relu_add.cpp
View file @
cd929111
...
...
@@ -84,7 +84,7 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[])
if
(
data_type
==
ConvDataType
::
F16_F16_F16
&&
in_layout
==
ConvInputLayout
::
NHWC
&&
wei_layout
==
ConvWeightLayout
::
KYXC
&&
out_layout
==
ConvOutputLayout
::
NHWK
)
{
ck
::
profiler
::
profile_conv_fwd_bias_relu_add_impl
e
<
2
,
ck
::
profiler
::
profile_conv_fwd_bias_relu_add_impl
<
2
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
...
...
profiler/profiler.cpp
View file @
cd929111
...
...
@@ -7,7 +7,7 @@
int
profile_gemm
(
int
,
char
*
[]);
int
profile_conv_fwd
(
int
,
char
*
[]);
//
int profile_conv_fwd_bias_relu_add(int, char*[]);
int
profile_conv_fwd_bias_relu_add
(
int
,
char
*
[]);
int
main
(
int
argc
,
char
*
argv
[])
{
...
...
@@ -19,15 +19,13 @@ int main(int argc, char* argv[])
{
return
profile_conv_fwd
(
argc
,
argv
);
}
#if 0
else
if
(
strcmp
(
argv
[
1
],
"conv_fwd_bias_relu_add"
)
==
0
)
{
return
profile_conv_fwd_bias_relu_add
(
argc
,
argv
);
}
#endif
else
{
printf
(
"arg1: tensor operation (conv_fwd: ForwardConvolution)
\n
"
);
printf
(
"arg1: tensor operation (
gemm: GEMM;
conv_fwd: ForwardConvolution)
\n
"
);
return
0
;
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment