Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e9575251
"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "fbbdd0e3be99969beb2dc6048ebae5d523e58503"
Commit
e9575251
authored
Oct 15, 2021
by
Jing Zhang
Browse files
test
parent
da207144
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
81 additions
and
104 deletions
+81
-104
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2_add.hpp
...l/include/tensor_operation/gridwise_gemm_dlops_v2_add.hpp
+0
-1
composable_kernel/include/utility/config.hpp
composable_kernel/include/utility/config.hpp
+1
-1
host/driver_offline/include/device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
...ward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
+5
-5
host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
...ward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
+6
-6
host/driver_offline/include/device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
...ward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
+6
-6
host/driver_offline/src/conv_add_fwd_driver_offline_nchwc.cpp
.../driver_offline/src/conv_add_fwd_driver_offline_nchwc.cpp
+3
-3
host/driver_offline/src/conv_fwd_driver_offline_nchwc.cpp
host/driver_offline/src/conv_fwd_driver_offline_nchwc.cpp
+22
-25
host/driver_offline/src/conv_maxpool_fwd_driver_offline_nchwc.cpp
...ver_offline/src/conv_maxpool_fwd_driver_offline_nchwc.cpp
+34
-54
host/host_tensor/include/host_conv.hpp
host/host_tensor/include/host_conv.hpp
+4
-3
No files found.
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2_add.hpp
View file @
e9575251
...
...
@@ -997,7 +997,6 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
}
}
// Bias
if
constexpr
(
bias_type
==
1
)
{
...
...
composable_kernel/include/utility/config.hpp
View file @
e9575251
...
...
@@ -78,7 +78,7 @@
// experimental implementation
#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
1
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
0
#endif
#ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
...
...
host/driver_offline/include/device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
View file @
e9575251
...
...
@@ -106,16 +106,16 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0
#elif
1
constexpr
auto
BlockSize
=
64
;
constexpr
auto
KPerBlock
=
16
;
constexpr
auto
KPerBlock
=
K
;
constexpr
auto
HoPerBlock
=
8
;
constexpr
auto
WoPerBlock
=
32
;
constexpr
auto
E1
=
2
*
9
;
constexpr
auto
E1
=
C0
*
9
;
constexpr
auto
E2
=
1
;
constexpr
auto
K2
=
2
;
constexpr
auto
E1PerBlock
=
2
;
constexpr
auto
E1PerBlock
=
C0
;
constexpr
auto
KPerThread
=
16
;
constexpr
auto
KPerThread
=
K
;
constexpr
auto
HoPerThread
=
2
;
constexpr
auto
WoPerThread
=
2
;
constexpr
auto
EPerThread
=
1
;
...
...
@@ -129,7 +129,7 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0
constexpr
auto
BThreadTransferSrcScalarPerVector_E2
=
E2
;
constexpr
auto
CThreadTransferDstScalarPerVector_K
=
8
;
constexpr
auto
CThreadTransferDstScalarPerVector_K
=
K1
;
#endif
const
auto
in_n_c0_hi_wi_c1_desc
=
...
...
host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
View file @
e9575251
...
...
@@ -99,21 +99,21 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1
#elif
1
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
KPerBlock
=
16
;
constexpr
index_t
KPerBlock
=
K
;
constexpr
index_t
HoPerBlock
=
8
;
constexpr
index_t
WoPerBlock
=
32
;
constexpr
index_t
E1
=
2
*
9
;
constexpr
index_t
E1
=
C0
*
Y
*
X
;
constexpr
index_t
E2
=
1
;
constexpr
index_t
K2
=
2
;
constexpr
index_t
E1PerBlock
=
2
;
constexpr
index_t
E1PerBlock
=
C0
;
constexpr
index_t
KPerThread
=
16
;
constexpr
index_t
KPerThread
=
K
;
constexpr
index_t
HoPerThread
=
2
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
index_t
EPerThread
=
1
;
using
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2
=
Sequence
<
1
,
9
,
1
,
1
,
E2
>
;
using
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2
=
Sequence
<
1
,
Y
*
X
,
1
,
1
,
E2
>
;
using
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2
=
Sequence
<
1
,
E1PerBlock
,
1
,
KPerBlock
,
1
>
;
...
...
@@ -122,7 +122,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1
constexpr
index_t
BThreadTransferSrcScalarPerVector_E2
=
E2
;
constexpr
index_t
CThreadTransferDstScalarPerVector_K
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector_K
=
K1
;
#endif
const
auto
in_n_c0_hi_wi_c1_desc
=
...
...
host/driver_offline/include/device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
View file @
e9575251
...
...
@@ -66,7 +66,7 @@ void device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1
DeviceMem
out_n_k0_ho_wo_k1_device_buf
(
sizeof
(
TOut
)
*
out_n_k0_ho_wo_k1
.
mDesc
.
GetElementSpace
());
DeviceMem
max_n_k0_hx_wx_k1_device_buf
(
sizeof
(
TOut
)
*
max_n_k0_hx_wx_k1
.
mDesc
.
GetElementSpace
());
max_n_k0_hx_wx_k1
.
mDesc
.
GetElementSpace
());
in_n_c0_hi_wi_c1_device_buf
.
ToDevice
(
in_n_c0_hi_wi_c1
.
mData
.
data
());
wei_k_c0_y_x_c1_device_buf
.
ToDevice
(
wei_k_c0_y_x_c1
.
mData
.
data
());
...
...
@@ -108,16 +108,16 @@ void device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1
#elif
1
constexpr
auto
BlockSize
=
64
;
constexpr
auto
KPerBlock
=
16
;
constexpr
auto
KPerBlock
=
K
;
constexpr
auto
HoPerBlock
=
8
;
constexpr
auto
WoPerBlock
=
32
;
constexpr
auto
E1
=
2
*
9
;
constexpr
auto
E1
=
C0
*
9
;
constexpr
auto
E2
=
1
;
constexpr
auto
K2
=
2
;
constexpr
auto
E1PerBlock
=
2
;
constexpr
auto
E1PerBlock
=
C0
;
constexpr
auto
KPerThread
=
16
;
constexpr
auto
KPerThread
=
K
;
constexpr
auto
HoPerThread
=
2
;
constexpr
auto
WoPerThread
=
2
;
constexpr
auto
EPerThread
=
1
;
...
...
@@ -131,7 +131,7 @@ void device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1
constexpr
auto
BThreadTransferSrcScalarPerVector_E2
=
E2
;
constexpr
auto
CThreadTransferDstScalarPerVector_K
=
8
;
constexpr
auto
CThreadTransferDstScalarPerVector_K
=
K1
;
#endif
const
auto
in_n_c0_hi_wi_c1_desc
=
...
...
host/driver_offline/src/conv_add_fwd_driver_offline_nchwc.cpp
View file @
e9575251
...
...
@@ -113,8 +113,8 @@ int main(int argc, char* argv[])
constexpr
auto
X
=
Number
<
3
>
{};
constexpr
auto
C0
=
Number
<
2
>
{};
constexpr
auto
C1
=
Number
<
8
>
{};
constexpr
auto
K0
=
Number
<
2
>
{};
constexpr
auto
K1
=
Number
<
8
>
{};
constexpr
auto
K0
=
Number
<
8
>
{};
#elif 0
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
Hi
=
Number
<
270
>
{};
...
...
@@ -123,8 +123,8 @@ int main(int argc, char* argv[])
constexpr
auto
X
=
Number
<
3
>
{};
constexpr
auto
C0
=
Number
<
2
>
{};
constexpr
auto
C1
=
Number
<
8
>
{};
constexpr
auto
K0
=
Number
<
2
>
{};
constexpr
auto
K1
=
Number
<
8
>
{};
constexpr
auto
K0
=
Number
<
8
>
{};
#elif 0
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
Hi
=
Number
<
135
>
{};
...
...
@@ -133,8 +133,8 @@ int main(int argc, char* argv[])
constexpr
auto
X
=
Number
<
3
>
{};
constexpr
auto
C0
=
Number
<
2
>
{};
constexpr
auto
C1
=
Number
<
8
>
{};
constexpr
auto
K0
=
Number
<
2
>
{};
constexpr
auto
K1
=
Number
<
8
>
{};
constexpr
auto
K0
=
Number
<
8
>
{};
#elif 1
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
Hi
=
Number
<
32
>
{};
...
...
host/driver_offline/src/conv_fwd_driver_offline_nchwc.cpp
View file @
e9575251
...
...
@@ -45,7 +45,7 @@ int main(int argc, char* argv[])
exit
(
1
);
}
constexpr
index_t
activ_type
=
0
;
constexpr
index_t
activ_type
=
1
;
const
ConvForwardAlgo
algo
=
static_cast
<
ConvForwardAlgo
>
(
std
::
stoi
(
argv
[
1
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
2
]);
...
...
@@ -100,48 +100,45 @@ int main(int argc, char* argv[])
constexpr auto X = Number<3>{};
constexpr auto C0 = Number<2>{};
constexpr auto C1 = Number<8>{};
constexpr auto K0 = Number<2>{};
constexpr auto K1 = Number<8>{};
constexpr auto K0 = Number<8>{};
#elif
1
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
Hi
=
Number
<
540
>
{};
constexpr
auto
Wi
=
Number
<
960
>
{};
constexpr
auto
Y
=
Number
<
3
>
{};
constexpr
auto
X
=
Number
<
3
>
{};
constexpr
auto
C0
=
Number
<
2
>
{};
constexpr
auto
C1
=
Number
<
8
>
{};
constexpr
auto
K1
=
Number
<
8
>
{};
constexpr
auto
K0
=
Number
<
8
>
{};
#elif
0
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
Hi
=
Number
<
540
>
{};
constexpr
auto
Wi
=
Number
<
960
>
{};
constexpr
auto
Y
=
Number
<
3
>
{};
constexpr
auto
X
=
Number
<
3
>
{};
constexpr
auto
C0
=
Number
<
2
>
{};
constexpr
auto
C1
=
Number
<
8
>
{};
constexpr
auto
K1
=
Number
<
8
>
{};
constexpr
auto
K0
=
Number
<
8
>
{};
#elif 1
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
Hi
=
Number
<
270
>
{};
constexpr
auto
Wi
=
Number
<
480
>
{};
constexpr
auto
Y
=
Number
<
3
>
{};
constexpr
auto
X
=
Number
<
3
>
{};
constexpr
auto
Y
=
Number
<
1
>
{};
constexpr
auto
X
=
Number
<
1
>
{};
constexpr
auto
C0
=
Number
<
2
>
{};
constexpr
auto
C1
=
Number
<
8
>
{};
constexpr
auto
K0
=
Number
<
2
>
{};
constexpr
auto
K1
=
Number
<
8
>
{};
constexpr
auto
K0
=
Number
<
8
>
{};
#elif 1
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
Hi
=
Number
<
135
>
{};
constexpr
auto
Wi
=
Number
<
240
>
{};
constexpr
auto
Y
=
Number
<
3
>
{};
constexpr
auto
X
=
Number
<
3
>
{};
constexpr
auto
C0
=
Number
<
2
>
{};
constexpr
auto
C1
=
Number
<
8
>
{};
constexpr
auto
K1
=
Number
<
8
>
{};
constexpr
auto
K0
=
Number
<
8
>
{};
#endif
constexpr
auto
conv_stride_h
=
I1
;
constexpr
auto
conv_stride_w
=
I1
;
constexpr
auto
conv_dilation_h
=
I1
;
constexpr
auto
conv_dilation_w
=
I1
;
#if 0
constexpr auto in_left_pad_h = I1;
constexpr auto in_left_pad_w = I1;
constexpr auto in_right_pad_h = I1;
constexpr auto in_right_pad_w = I1;
#else
constexpr
auto
in_left_pad_h
=
I0
;
constexpr
auto
in_left_pad_w
=
I0
;
constexpr
auto
in_right_pad_h
=
I0
;
constexpr
auto
in_right_pad_w
=
I0
;
#endif
constexpr
auto
YEff
=
(
Y
-
I1
)
*
conv_dilation_h
+
I1
;
constexpr
auto
XEff
=
(
X
-
I1
)
*
conv_dilation_w
+
I1
;
...
...
host/driver_offline/src/conv_maxpool_fwd_driver_offline_nchwc.cpp
View file @
e9575251
...
...
@@ -95,17 +95,17 @@ int main(int argc, char* argv[])
constexpr
index_t
activ_type
=
1
;
#if
1
#if
0
constexpr auto N = Number<1>{};
constexpr auto Hi = Number<1080>{};
constexpr auto Wi = Number<1920>{};
constexpr auto Y = Number<3>{};
constexpr auto X = Number<3>{};
constexpr
auto
C0
=
Number
<
2
>
{};
constexpr auto C0 = Number<
1
>{};
constexpr auto C1 = Number<8>{};
constexpr auto K0 = Number<2>{};
constexpr auto K1 = Number<8>{};
constexpr
auto
K0
=
Number
<
8
>
{};
#elif 1
#elif
0
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
Hi
=
Number
<
540
>
{};
constexpr
auto
Wi
=
Number
<
960
>
{};
...
...
@@ -113,9 +113,9 @@ int main(int argc, char* argv[])
constexpr
auto
X
=
Number
<
3
>
{};
constexpr
auto
C0
=
Number
<
2
>
{};
constexpr
auto
C1
=
Number
<
8
>
{};
constexpr
auto
K0
=
Number
<
2
>
{};
constexpr
auto
K1
=
Number
<
8
>
{};
constexpr
auto
K0
=
Number
<
8
>
{};
#elif 0
#elif 1
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
Hi
=
Number
<
270
>
{};
constexpr
auto
Wi
=
Number
<
480
>
{};
...
...
@@ -123,28 +123,8 @@ int main(int argc, char* argv[])
constexpr
auto
X
=
Number
<
3
>
{};
constexpr
auto
C0
=
Number
<
2
>
{};
constexpr
auto
C1
=
Number
<
8
>
{};
constexpr
auto
K0
=
Number
<
2
>
{};
constexpr
auto
K1
=
Number
<
8
>
{};
constexpr
auto
K0
=
Number
<
8
>
{};
#elif 0
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
Hi
=
Number
<
135
>
{};
constexpr
auto
Wi
=
Number
<
240
>
{};
constexpr
auto
Y
=
Number
<
3
>
{};
constexpr
auto
X
=
Number
<
3
>
{};
constexpr
auto
C0
=
Number
<
2
>
{};
constexpr
auto
C1
=
Number
<
8
>
{};
constexpr
auto
K1
=
Number
<
8
>
{};
constexpr
auto
K0
=
Number
<
8
>
{};
#elif 1
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
Hi
=
Number
<
32
>
{};
constexpr
auto
Wi
=
Number
<
32
>
{};
constexpr
auto
Y
=
Number
<
3
>
{};
constexpr
auto
X
=
Number
<
3
>
{};
constexpr
auto
C0
=
Number
<
2
>
{};
constexpr
auto
C1
=
Number
<
8
>
{};
constexpr
auto
K1
=
Number
<
8
>
{};
constexpr
auto
K0
=
Number
<
8
>
{};
#endif
constexpr
auto
conv_stride_h
=
I1
;
...
...
@@ -290,39 +270,39 @@ int main(int argc, char* argv[])
{
const
auto
tmp
=
f_make_for_device_nchwc
();
device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1
<
in_data_t
,
acc
_data_t
,
out
_data_t
,
activ_type
>
(
tmp
[
I0
],
// in_lengths_dev
tmp
[
I1
],
// wei_lengths_dev
tmp
[
I2
],
// max_lengths_dev
tmp
[
I3
],
// out_lengths_dev
tmp
[
I4
],
// conv_strides_dev
tmp
[
I5
],
// conv_dilations_dev
tmp
[
I6
],
// in_left_pads_dev
tmp
[
I7
],
// in_right_pads_dev
in
,
wei
,
bias
,
out_device
,
max_device
,
nrepeat
);
device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1
<
in
_data_t
,
acc
_data_t
,
out_data_t
,
activ_type
>
(
tmp
[
I0
],
// in_lengths_dev
tmp
[
I1
],
// wei_lengths_dev
tmp
[
I2
],
// max_lengths_dev
tmp
[
I3
],
// out_lengths_dev
tmp
[
I4
],
// conv_strides_dev
tmp
[
I5
],
// conv_dilations_dev
tmp
[
I6
],
// in_left_pads_dev
tmp
[
I7
],
// in_right_pads_dev
in
,
wei
,
bias
,
out_device
,
max_device
,
nrepeat
);
}
#endif
if
(
do_verification
)
{
host_direct_convolution_maxpool_nchwc
(
in
,
wei
,
bias
,
out_host
,
max_host
,
make_tuple
(
conv_stride_h
,
conv_stride_w
),
make_tuple
(
conv_dilation_h
,
conv_dilation_w
),
make_tuple
(
in_left_pad_h
,
in_left_pad_w
),
make_tuple
(
in_right_pad_h
,
in_right_pad_w
),
activ_type
);
wei
,
bias
,
out_host
,
max_host
,
make_tuple
(
conv_stride_h
,
conv_stride_w
),
make_tuple
(
conv_dilation_h
,
conv_dilation_w
),
make_tuple
(
in_left_pad_h
,
in_left_pad_w
),
make_tuple
(
in_right_pad_h
,
in_right_pad_w
),
activ_type
);
check_error
(
out_host
,
out_device
);
check_error
(
max_host
,
max_device
);
...
...
host/host_tensor/include/host_conv.hpp
View file @
e9575251
...
...
@@ -4,7 +4,7 @@
template
<
typename
T
>
inline
auto
activ
(
T
v
,
const
ck
::
index_t
activ_type
)
{
const
T
alpha
=
0.30000001192092896
;
const
T
alpha
=
0.30000001192092896
;
switch
(
activ_type
)
{
case
0
:
return
v
;
...
...
@@ -147,7 +147,8 @@ void host_direct_convolution_nchwc(const Tensor<TIn>& in,
}
}
}
out
(
n
,
k0
,
ho
,
wo
,
k1
)
=
activ
(
v
,
activ_type
)
+
bias
(
k0
,
k1
);
v
+=
bias
(
k0
,
k1
);
out
(
n
,
k0
,
ho
,
wo
,
k1
)
=
activ
(
v
,
activ_type
);
};
make_ParallelTensorFunctor
(
f_nchw
,
...
...
@@ -275,7 +276,7 @@ void host_direct_convolution_maxpool_nchwc(const Tensor<TIn>& in,
}
v
+=
bias
(
k0
,
k1
);
v
=
activ
(
v
,
activ_type
);
v
=
activ
(
v
,
activ_type
);
out_host
(
n
,
k0
,
ho
,
wo
,
k1
)
=
v
;
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment