Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
9b280cc5
Commit
9b280cc5
authored
Sep 27, 2019
by
Chao Liu
Browse files
remove dead code
parent
98a2cfcc
Changes
26
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
43 additions
and
347 deletions
+43
-347
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
.../device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
+37
-42
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded.hpp
..._convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded.hpp
+1
-2
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
.../device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+1
-2
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded.hpp
..._convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded.hpp
+1
-2
driver/include/device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp
...ice_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp
+0
-296
driver/src/driver.cpp
driver/src/driver.cpp
+3
-3
No files found.
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
View file @
9b280cc5
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
#include "device.hpp"
#include "device.hpp"
#include "tensor.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
//#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp"
#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp"
template
<
class
T
,
template
<
class
T
,
...
@@ -20,7 +19,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -20,7 +19,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
Tensor
<
T
>&
out_nkhw
,
Tensor
<
T
>&
out_nkhw
,
ConvStrides
,
ConvStrides
,
ConvDilations
,
ConvDilations
,
index_t
nrepeat
)
ck
::
index_t
nrepeat
)
{
{
using
namespace
ck
;
using
namespace
ck
;
...
@@ -171,46 +170,42 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -171,46 +170,42 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
constexpr
auto
gridwise_conv
=
constexpr
auto
gridwise_conv
=
#if 0
GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
<
GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
GridSize
,
#else
BlockSize
,
GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
T
,
#endif
decltype
(
in_nchw_desc
),
<
GridSize
,
decltype
(
wei_kcyx_desc
),
BlockSize
,
decltype
(
out_nkhw_desc
),
T
,
ConvStrides
,
decltype
(
in_nchw_desc
),
ConvDilations
,
decltype
(
wei_kcyx_desc
),
BPerBlock
,
decltype
(
out_nkhw_desc
),
KPerBlock
,
ConvStrides
,
EPerBlock
,
ConvDilations
,
GemmNRepeat
,
BPerBlock
,
GemmMPerThreadSubC
,
KPerBlock
,
GemmNPerThreadSubC
,
EPerBlock
,
GemmMLevel0Cluster
,
GemmNRepeat
,
GemmNLevel0Cluster
,
GemmMPerThreadSubC
,
GemmMLevel1Cluster
,
GemmNPerThreadSubC
,
GemmNLevel1Cluster
,
GemmMLevel0Cluster
,
GemmKPerThreadLoop
,
GemmNLevel0Cluster
,
GemmDataPerReadA
,
GemmMLevel1Cluster
,
GemmDataPerReadB
,
GemmNLevel1Cluster
,
InBlockCopySubLengths_E_N1_B_N2
,
GemmKPerThreadLoop
,
InBlockCopyClusterLengths_E_N1_B_N2
,
GemmDataPerReadA
,
InBlockCopyThreadClusterArrangeOrder
,
GemmDataPerReadB
,
InBlockCopySrcAccessOrder
,
InBlockCopySubLengths_E_N1_B_N2
,
InBlockCopyDstAccessOrder
,
InBlockCopyClusterLengths_E_N1_B_N2
,
InBlockCopySrcDataPerRead_B
,
InBlockCopyThreadClusterArrangeOrder
,
InBlockCopyDstDataPerWrite_N2
,
InBlockCopySrcAccessOrder
,
WeiBlockCopySubLengths_E_K
,
InBlockCopyDstAccessOrder
,
WeiBlockCopyClusterLengths_E_K
,
InBlockCopySrcDataPerRead_B
,
WeiBlockCopyThreadClusterArrangeOrder
,
InBlockCopyDstDataPerWrite_N2
,
WeiBlockCopySrcAccessOrder
,
WeiBlockCopySubLengths_E_K
,
WeiBlockCopyDstAccessOrder
,
WeiBlockCopyClusterLengths_E_K
,
WeiBlockCopySrcDataPerRead_E
,
WeiBlockCopyThreadClusterArrangeOrder
,
WeiBlockCopyDstDataPerWrite_K
>
{};
WeiBlockCopySrcAccessOrder
,
WeiBlockCopyDstAccessOrder
,
WeiBlockCopySrcDataPerRead_E
,
WeiBlockCopyDstDataPerWrite_K
>
{};
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
{
...
...
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded.hpp
View file @
9b280cc5
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
#include "device.hpp"
#include "device.hpp"
#include "tensor.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
//#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded.hpp"
#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp"
#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp"
template
<
typename
T
,
template
<
typename
T
,
...
@@ -24,7 +23,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded(InDesc,
...
@@ -24,7 +23,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded(InDesc,
ConvDilations
,
ConvDilations
,
LeftPads
,
LeftPads
,
RightPads
,
RightPads
,
index_t
nrepeat
)
ck
::
index_t
nrepeat
)
{
{
using
namespace
ck
;
using
namespace
ck
;
...
...
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
9b280cc5
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
#include "device.hpp"
#include "device.hpp"
#include "tensor.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
//#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp"
#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp"
using
namespace
ck
;
using
namespace
ck
;
...
@@ -22,7 +21,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
...
@@ -22,7 +21,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
Tensor
<
T
>&
out_nkhw
,
Tensor
<
T
>&
out_nkhw
,
ConvStrides
,
ConvStrides
,
ConvDilations
,
ConvDilations
,
index_t
nrepeat
)
ck
::
index_t
nrepeat
)
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded.hpp
View file @
9b280cc5
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
#include "device.hpp"
#include "device.hpp"
#include "tensor.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
//#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded.hpp"
#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp"
#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp"
template
<
class
T
,
template
<
class
T
,
...
@@ -24,7 +23,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded(InDesc,
...
@@ -24,7 +23,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded(InDesc,
ConvDilations
,
ConvDilations
,
LeftPads
,
LeftPads
,
RightPads
,
RightPads
,
index_t
nrepeat
)
ck
::
index_t
nrepeat
)
{
{
using
namespace
ck
;
using
namespace
ck
;
...
...
driver/include/device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp
deleted
100644 → 0
View file @
98a2cfcc
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp"
using
namespace
ck
;
// Host-side driver for the padded implicit-GEMM convolution kernel
// `gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded`.
//
// Steps performed, in order:
//   1. Re-lay-out the weight tensor (K,C,Y,X) -> (C,Y,X,K) and the input tensor
//      (N,C,Hi,Wi) -> (C,Hi,Wi,N) on the host.
//   2. Create three DeviceMem buffers (input, weight, output) and call ToDevice on each.
//   3. Launch the kernel `nrepeat` times, printing the value returned by launch_kernel
//      after every launch.
//   4. Call FromDevice on the output buffer and re-lay-out (K,Ho,Wo,N) -> (N,K,Ho,Wo)
//      into `out_nkhw`.
//
// Parameters:
//   InDesc / WeiDesc / OutDesc  tag arguments; only their types are used (each is
//                               default-constructed to query lengths at compile time).
//   in_nchw, wei_kcyx           source tensors, read only.
//   out_nkhw                    destination tensor, overwritten with the result.
//   LowerPads / UpperPads       tag arguments forwarded to the kernel as template arguments.
//   nrepeat                     number of kernel launches.
template <class T, class InDesc, class WeiDesc, class OutDesc, class LowerPads, class UpperPads>
void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
                                                              const Tensor<T>& in_nchw,
                                                              WeiDesc,
                                                              const Tensor<T>& wei_kcyx,
                                                              OutDesc,
                                                              Tensor<T>& out_nkhw,
                                                              LowerPads,
                                                              UpperPads,
                                                              index_t nrepeat)
{
    // Compile-time dimension selectors.
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    // The descriptors carry no runtime state here; instantiate them from their types.
    constexpr auto in_nchw_desc  = InDesc{};
    constexpr auto wei_kcyx_desc = WeiDesc{};
    constexpr auto out_nkhw_desc = OutDesc{};

    // Problem sizes, gathered from whichever descriptor the original code consulted.
    constexpr index_t Hi = in_nchw_desc.GetLength(I2);
    constexpr index_t Wi = in_nchw_desc.GetLength(I3);

    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

    constexpr index_t K = wei_kcyx_desc.GetLength(I0);
    constexpr index_t C = wei_kcyx_desc.GetLength(I1);
    constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
    constexpr index_t X = wei_kcyx_desc.GetLength(I3);

    // ---- weight: KCYX -> CYXK ------------------------------------------------------------
    auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
    ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");

    Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));

    auto copy_weight_element = [&](auto k, auto c, auto y, auto x) {
        wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x);
    };

    make_ParallelTensorFunctor(copy_weight_element, K, C, Y, X)(
        std::thread::hardware_concurrency());

    // ---- input: NCHW -> CHWN -------------------------------------------------------------
    auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence<C, Hi, Wi, N>{});
    ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: ");

    Tensor<T> in_chwn(make_TensorDescriptor(in_chwn_desc));

    auto copy_input_element = [&](auto n, auto c, auto hi, auto wi) {
        in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi);
    };

    make_ParallelTensorFunctor(copy_input_element, N, C, Hi, Wi)(
        std::thread::hardware_concurrency());

    // ---- output staging tensor in KHWN ---------------------------------------------------
    auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence<K, Ho, Wo, N>{});
    ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");

    Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));

    // ---- buffers sized in bytes from each tensor's element space -------------------------
    std::size_t bytes_per_element = sizeof(T);

    DeviceMem in_chwn_device_buf(bytes_per_element * in_chwn.mDesc.GetElementSpace());
    DeviceMem wei_cyxk_device_buf(bytes_per_element * wei_cyxk.mDesc.GetElementSpace());
    DeviceMem out_khwn_device_buf(bytes_per_element * out_khwn.mDesc.GetElementSpace());

    in_chwn_device_buf.ToDevice(in_chwn.mData.data());
    wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
    // The output staging tensor is pushed as well, exactly as the other two are.
    out_khwn_device_buf.ToDevice(out_khwn.mData.data());

    // ---- tuning parameters ---------------------------------------------------------------
    // Exactly one branch of this chain is compiled: the first whose condition is non-zero.
    // NOTE: the later "#elif 1" (padding configurations) is shadowed by the earlier
    // "#elif 1" and is therefore never selected as the chain stands.
#if 0
    constexpr index_t NPerBlock  = 1;
    constexpr index_t KPerBlock  = 1;
    constexpr index_t CPerBlock  = 1;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 4;

    constexpr index_t NPerThread  = 1;
    constexpr index_t KPerThread  = 1;
    constexpr index_t CPerThread  = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;

    constexpr index_t WeiBlockCopyThreadPerDim0 = 1;
    constexpr index_t WeiBlockCopyThreadPerDim1 = 1;

    constexpr index_t BlockSize = 8;
#elif 1
    // for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
    constexpr index_t NPerBlock  = 16;
    constexpr index_t KPerBlock  = 64;
    constexpr index_t CPerBlock  = 4;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 4;

    constexpr index_t NPerThread  = 4;
    constexpr index_t KPerThread  = 16;
    constexpr index_t CPerThread  = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;

    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
    constexpr index_t WeiBlockCopyThreadPerDim1 = 32;

    constexpr index_t BlockSize = 128;
#elif 0
    // 3x3 58x58, NKC = 16,256,128
    constexpr index_t NPerBlock  = 8;
    constexpr index_t KPerBlock  = 64;
    constexpr index_t CPerBlock  = 2;
    constexpr index_t HoPerBlock = 4;
    constexpr index_t WoPerBlock = 4;

    constexpr index_t NPerThread  = 4;
    constexpr index_t KPerThread  = 16;
    constexpr index_t CPerThread  = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;

    constexpr index_t BlockSize = 128;
#elif 0
    // for 5x5, 36x36
    constexpr index_t NPerBlock  = 16;
    constexpr index_t KPerBlock  = 64;
    constexpr index_t CPerBlock  = 2;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 4;

    constexpr index_t NPerThread  = 4;
    constexpr index_t KPerThread  = 16;
    constexpr index_t CPerThread  = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;

    constexpr index_t BlockSize = 128;
#elif 0
    // for 7x7, 38x38
    constexpr index_t NPerBlock  = 8;
    constexpr index_t KPerBlock  = 64;
    constexpr index_t CPerBlock  = 2;
    constexpr index_t HoPerBlock = 4;
    constexpr index_t WoPerBlock = 4;

    constexpr index_t NPerThread  = 4;
    constexpr index_t KPerThread  = 16;
    constexpr index_t CPerThread  = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;

    constexpr index_t BlockSize = 128;
#elif 0
    // for 3x3, 56x56
    constexpr index_t NPerBlock  = 32;
    constexpr index_t KPerBlock  = 64;
    constexpr index_t CPerBlock  = 4;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;

    constexpr index_t NPerThread  = 4;
    constexpr index_t KPerThread  = 16;
    constexpr index_t CPerThread  = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;

    constexpr index_t BlockSize = 128;
#elif 1
    // 3x3 56x56, NKC = 16,256,128, with padding
    // 3x3 28x28, NKC = 16,512,256, with padding
    // 3x3 20x84, NKC = 16,256,256, with padding
    constexpr index_t NPerBlock  = 16;
    constexpr index_t KPerBlock  = 64;
    constexpr index_t CPerBlock  = 2;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 4;

    constexpr index_t NPerThread  = 4;
    constexpr index_t KPerThread  = 16;
    constexpr index_t CPerThread  = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;

    constexpr index_t WeiBlockCopyThreadPerDim0 = 2;
    constexpr index_t WeiBlockCopyThreadPerDim1 = 64;

    constexpr index_t BlockSize = 128;
#elif 0
    // for 5x5 filter, 20x84 image, 1x1 padding
    constexpr index_t NPerBlock  = 16;
    constexpr index_t KPerBlock  = 64;
    constexpr index_t CPerBlock  = 1;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 4;

    constexpr index_t NPerThread  = 4;
    constexpr index_t KPerThread  = 16;
    constexpr index_t CPerThread  = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;

    constexpr index_t BlockSize = 128;
#elif 0
    // 5x5 filter, 28x28 image, 2x2 padding
    constexpr index_t NPerBlock  = 16;
    constexpr index_t KPerBlock  = 32;
    constexpr index_t CPerBlock  = 2;
    constexpr index_t HoPerBlock = 4;
    constexpr index_t WoPerBlock = 4;

    constexpr index_t NPerThread  = 4;
    constexpr index_t KPerThread  = 16;
    constexpr index_t CPerThread  = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;

    constexpr index_t BlockSize = 128;
#elif 0
    // for 1x1, 28x28
    constexpr index_t NPerBlock  = 16;
    constexpr index_t KPerBlock  = 128;
    constexpr index_t CPerBlock  = 8;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;

    constexpr index_t NPerThread  = 4;
    constexpr index_t KPerThread  = 16;
    constexpr index_t CPerThread  = 2;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;

    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
    constexpr index_t WeiBlockCopyThreadPerDim1 = 32;

    constexpr index_t BlockSize = 128;
#endif

    // One block per (N, K, Ho, Wo) tile; each extent is rounded up to whole tiles.
    constexpr index_t NTileCount  = (N + NPerBlock - 1) / NPerBlock;
    constexpr index_t KTileCount  = (K + KPerBlock - 1) / KPerBlock;
    constexpr index_t HoTileCount = (Ho + HoPerBlock - 1) / HoPerBlock;
    constexpr index_t WoTileCount = (Wo + WoPerBlock - 1) / WoPerBlock;

    constexpr index_t GridSize = NTileCount * KTileCount * HoTileCount * WoTileCount;

    printf("%s: BlockSize %u, GridSize %u\n", __func__, BlockSize, GridSize);

    // Descriptor types handed to the kernel as template arguments.
    using InChwnDesc  = decltype(in_chwn_desc);
    using WeiCyxkDesc = decltype(wei_cyxk_desc);
    using OutKhwnDesc = decltype(out_khwn_desc);

    for(index_t rep = 0; rep < nrepeat; ++rep)
    {
        // The value returned by launch_kernel is reported below as milliseconds.
        float elapsed_ms =
            launch_kernel(gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded<
                              GridSize,
                              BlockSize,
                              T,
                              InChwnDesc,
                              WeiCyxkDesc,
                              OutKhwnDesc,
                              LowerPads,
                              UpperPads,
                              NPerBlock,
                              KPerBlock,
                              CPerBlock,
                              HoPerBlock,
                              WoPerBlock,
                              NPerThread,
                              KPerThread,
                              CPerThread,
                              HoPerThread,
                              WoPerThread,
                              WeiBlockCopyThreadPerDim0,
                              WeiBlockCopyThreadPerDim1>,
                          dim3(GridSize),
                          dim3(BlockSize),
                          static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
                          static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
                          static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));

        printf("Elapsed time : %f ms\n", elapsed_ms);

        // Pause between launches for elapsed_ms * 1000 microseconds, capped at 10000.
        usleep(std::min(elapsed_ms * 1000, float(10000)));
    }

    out_khwn_device_buf.FromDevice(out_khwn.mData.data());

    // ---- output: KHWN -> NKHW ------------------------------------------------------------
    auto copy_output_element = [&](auto k, auto ho, auto wo, auto n) {
        out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n);
    };

    make_ParallelTensorFunctor(copy_output_element, K, Ho, Wo, N)(
        std::thread::hardware_concurrency());
}
driver/src/driver.cpp
View file @
9b280cc5
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
#include "device.hpp"
#include "device.hpp"
#include "conv_common.hpp"
#include "conv_common.hpp"
#include "host_conv.hpp"
#include "host_conv.hpp"
#include "device_convolution_direct_v2_nchw_kcyx_nkhw.hpp"
//
#include "device_convolution_direct_v2_nchw_kcyx_nkhw.hpp"
//#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp"
//#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp"
//#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp"
//#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp"
//#include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp"
//#include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp"
...
@@ -448,7 +448,7 @@ int main(int argc, char* argv[])
...
@@ -448,7 +448,7 @@ int main(int argc, char* argv[])
ConvStrides
{},
ConvStrides
{},
ConvDilations
{},
ConvDilations
{},
nrepeat
);
nrepeat
);
#elif
1
#elif
0
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded
(
in_nchw_desc
,
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded
(
in_nchw_desc
,
in_nchw
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx_desc
,
...
@@ -490,7 +490,7 @@ int main(int argc, char* argv[])
...
@@ -490,7 +490,7 @@ int main(int argc, char* argv[])
ConvStrides
{},
ConvStrides
{},
ConvDilations
{},
ConvDilations
{},
nrepeat
);
nrepeat
);
#elif
0
#elif
1
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded
(
in_nchw_desc
,
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded
(
in_nchw_desc
,
in_nchw
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx_desc
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment