yangql / composable_kernel-1 · Commits · 79d9b108

Commit 79d9b108, authored Mar 18, 2019 by Chao Liu

    adding fp16 direct that reads pre-vectorized data

parent 28325204
Showing 8 changed files with 92 additions and 166 deletions (+92 / -166)
driver/device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp            +34  -15
src/include/blockwise_2d_tensor_op.hip.hpp                                   +1   -1
src/include/blockwise_4d_tensor_op.hip.hpp                                   +2   -2
src/include/common.hip.hpp                                                   +1   -91
src/include/config.h.in                                                      +0   -2
src/include/functional.hip.hpp                                               +8   -0
src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp  +26  -20
src/include/threadwise_direct_convolution.hip.hpp                            +20  -35
driver/device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp
@@ -13,8 +13,8 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
                                                            unsigned nrepeat)
{
    constexpr unsigned NVector = 1;

-    using vector_type_t = vector_type<T, NVector>;
-    using vector_t      = typename vector_type_t::VectorType;
+    using vector_t     = vector_type<T, NVector>;
+    using vector_mem_t = typename vector_t::MemoryType;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
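Throughout this commit the trait's vector alias is renamed from VectorType to MemoryType, and the vector_type specializations themselves move out of common.hip.hpp into data_type.hip.hpp (a file not shown in this diff). A minimal sketch of what such a trait looks like, assuming the relocated definitions mirror the ones removed from common.hip.hpp further below:

// Hypothetical sketch only; the actual definitions live in data_type.hip.hpp, outside this diff.
#include "hip/hip_runtime.h" // float2 / float4 vector types

template <class T, unsigned N>
struct vector_type;

template <>
struct vector_type<float, 1>
{
    using MemoryType = float; // the unit a vectorized load/store actually moves
};

template <>
struct vector_type<float, 2>
{
    using MemoryType = float2;
};

template <>
struct vector_type<float, 4>
{
    using MemoryType = float4;
};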
@@ -41,40 +41,41 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
    auto in_nchw_vec_desc = make_ConstantTensorDescriptor(Sequence<N, C / NVector, Hi, Wi>{});
    ostream_ConstantTensorDescriptor(in_nchw_vec_desc, std::cout << "in_nchw_vec_desc: ");

-    Tensor<vector_t> in_nchw_vec(make_TensorDescriptor(in_nchw_vec_desc));
+    Tensor<vector_mem_t> in_nchw_vec(make_TensorDescriptor(in_nchw_vec_desc));

    auto f_vectorized_nchw = [&](auto n, auto c, auto h, auto w) {
#if 1
        in_nchw_vec(n, c, h, w) = in_nchw(n, c, h, w);
#else
        in_nchw_vec(n, c, h, w) =
-            vector_type_t::pack(in_nchw(n, 2 * c, h, w), in_nchw(n, 2 * c + 1, h, w));
+            vector_t::Pack(in_nchw(n, 2 * c, h, w), in_nchw(n, 2 * c + 1, h, w));
#endif
    };

-    make_ParallelTensorFunctor(f_vectorized_nchw, N, C, Hi, Wi)(
+    make_ParallelTensorFunctor(f_vectorized_nchw, N, C / NVector, Hi, Wi)(
        std::thread::hardware_concurrency());

    // vectorize weight
    auto wei_kcyx_vec_desc = make_ConstantTensorDescriptor(Sequence<K, C / NVector, Y, X>{});
    ostream_ConstantTensorDescriptor(wei_kcyx_vec_desc, std::cout << "wei_kcyx_vec_desc: ");

-    Tensor<vector_t> wei_kcyx_vec(make_TensorDescriptor(wei_kcyx_vec_desc));
+    Tensor<vector_mem_t> wei_kcyx_vec(make_TensorDescriptor(wei_kcyx_vec_desc));

    auto f_vectorized_kcyx = [&](auto k, auto c, auto y, auto x) {
#if 1
        wei_kcyx_vec(k, c, y, x) = wei_kcyx(k, c, y, x);
#else
        wei_kcyx_vec(k, c, y, x) =
-            vector_type_t::pack(wei_kcyx(k, 2 * c, y, x), wei_kcyx(k, 2 * c + 1, y, x));
+            vector_t::Pack(wei_kcyx(k, 2 * c, y, x), wei_kcyx(k, 2 * c + 1, y, x));
#endif
    };

-    make_ParallelTensorFunctor(f_vectorized_kcyx, K, C, Y, X)(
-        std::thread::hardware_concurrency());
+    make_ParallelTensorFunctor(f_vectorized_kcyx, K, C / NVector, Y, X)(
+        std::thread::hardware_concurrency());

    //
-    DeviceMem in_nchw_vec_device_buf(sizeof(vector_t) * in_nchw_vec.mDesc.GetElementSpace());
-    DeviceMem wei_kcyx_vec_device_buf(sizeof(vector_t) * wei_kcyx_vec.mDesc.GetElementSpace());
+    DeviceMem in_nchw_vec_device_buf(sizeof(vector_mem_t) * in_nchw_vec.mDesc.GetElementSpace());
+    DeviceMem wei_kcyx_vec_device_buf(sizeof(vector_mem_t) * wei_kcyx_vec.mDesc.GetElementSpace());
    DeviceMem out_nkhw_device_buf(sizeof(T) * out_nkhw.mDesc.GetElementSpace());

    in_nchw_vec_device_buf.ToDevice(in_nchw_vec.mData.data());
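The disabled #else branches above show the intended fp16 path that gives the commit its name: with T = half and NVector = 2, channels 2c and 2c+1 of the original NCHW tensor are packed into one element of an N x C/NVector x H x W tensor, so the kernel later reads data that is already vectorized. A standalone host-side sketch of that packing (plain C++, hypothetical names, not the repo's code):

#include <cstddef>
#include <vector>

// Host-side illustration: channels (2c, 2c+1) of an NCHW tensor become one two-wide
// element of an N x (C/2) x H x W tensor. 'Vec2' stands in for half2.
template <class T>
struct Vec2
{
    T lo, hi;
};

template <class T>
std::vector<Vec2<T>> pack_channel_pairs(const std::vector<T>& in, int N, int C, int H, int W)
{
    std::vector<Vec2<T>> out(std::size_t(N) * (C / 2) * H * W);

    auto in_idx  = [&](int n, int c, int h, int w) { return ((std::size_t(n) * C + c) * H + h) * W + w; };
    auto out_idx = [&](int n, int c, int h, int w) { return ((std::size_t(n) * (C / 2) + c) * H + h) * W + w; };

    for(int n = 0; n < N; ++n)
        for(int c = 0; c < C / 2; ++c)
            for(int h = 0; h < H; ++h)
                for(int w = 0; w < W; ++w)
                    out[out_idx(n, c, h, w)] = Vec2<T>{in[in_idx(n, 2 * c, h, w)],
                                                       in[in_idx(n, 2 * c + 1, h, w)]};

    return out;
}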
@@ -82,7 +83,7 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

#if 1
-    // 3x3, 34x34, 128 thread
+    // 3x3, 34x34, 128 thread, fp32, vector = 1
    constexpr unsigned NPerBlock = 2;
    constexpr unsigned KPerBlock = 32;
    constexpr unsigned CPerBlock = 4;

@@ -96,24 +97,42 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
    constexpr unsigned WoPerThread = 2;

    constexpr unsigned InBlockCopyDataPerRead  = 2;
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
+    constexpr unsigned WeiBlockCopyDataPerRead = 2;

    constexpr unsigned BlockSize = 128;
#elif 1
-    // 3x3, 34x34, 128 thread, fp16
+    // 3x3, 34x34, 128 thread, fp32, vector = 2
    constexpr unsigned NPerBlock = 2;
    constexpr unsigned KPerBlock = 32;
-    constexpr unsigned CPerBlock = 4;
+    constexpr unsigned CPerBlock = 2;

    constexpr unsigned HoPerBlock = 2;
    constexpr unsigned WoPerBlock = 32;

    constexpr unsigned NPerThread = 2;
    constexpr unsigned KPerThread = 4;
-    constexpr unsigned CPerThread = 2;
+    constexpr unsigned CPerThread = 1;

    constexpr unsigned HoPerThread = 2;
    constexpr unsigned WoPerThread = 2;

    constexpr unsigned InBlockCopyDataPerRead  = 2;
    constexpr unsigned WeiBlockCopyDataPerRead = 2;

    constexpr unsigned BlockSize = 128;
+#elif 1
+    // 3x3, 34x34, 128 thread, fp16
+    constexpr unsigned NPerBlock = 2;
+    constexpr unsigned KPerBlock = 32;
+    constexpr unsigned CPerBlock = 4;
+
+    constexpr unsigned HoPerBlock = 2;
+    constexpr unsigned WoPerBlock = 32;
+
+    constexpr unsigned NPerThread = 2;
+    constexpr unsigned KPerThread = 4;
+    constexpr unsigned CPerThread = 2;
+
+    constexpr unsigned HoPerThread = 2;
+    constexpr unsigned WoPerThread = 2;
+
+    constexpr unsigned InBlockCopyDataPerRead  = 2;
+    constexpr unsigned WeiBlockCopyDataPerRead = 4;
+
+    constexpr unsigned BlockSize = 128;
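For reference, the second configuration's tile sizes decompose the block tile exactly over the 128 threads of BlockSize if one assumes the usual per-thread tiling, i.e. (NPerBlock/NPerThread) * (KPerBlock/KPerThread) * (HoPerBlock/HoPerThread) * (WoPerBlock/WoPerThread) = 1 * 8 * 1 * 16 = 128. That relation is inferred, not stated in the diff; as a compile-time sketch:

// Hypothetical consistency check for the second (fp32, vector = 2) configuration above;
// how the block tile maps onto threads is an assumption, not spelled out in this commit.
constexpr unsigned NPerBlock = 2, KPerBlock = 32, HoPerBlock = 2, WoPerBlock = 32;
constexpr unsigned NPerThread = 2, KPerThread = 4, HoPerThread = 2, WoPerThread = 2;

constexpr unsigned ThreadsNeeded = (NPerBlock / NPerThread) * (KPerBlock / KPerThread) *
                                   (HoPerBlock / HoPerThread) * (WoPerBlock / WoPerThread);

static_assert(ThreadsNeeded == 128, "block tile should decompose over BlockSize = 128 threads");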
src/include/blockwise_2d_tensor_op.hip.hpp
@@ -373,7 +373,7 @@ template <unsigned BlockSize,
          unsigned DataPerRead>
struct Blockwise2dTensorCopy3
{
-    using vector_t = typename vector_type<Float, DataPerRead>::VectorType;
+    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    unsigned mSrcMyThreadOffset;
    unsigned mDstMyThreadOffset;
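In these copy structs, vector_t (now the trait's MemoryType) is the unit moved per access, so DataPerRead scalars travel as one vector load/store. A stripped-down illustration of that access pattern for DataPerRead = 4, with a hypothetical helper and none of the descriptor machinery:

#include "hip/hip_runtime.h"

// Illustration only: copy n floats, four at a time, through float4 loads/stores.
// Assumes src and dst are 16-byte aligned and n is a multiple of 4.
__device__ void copy_float4(const float* __restrict__ src, float* __restrict__ dst, unsigned n)
{
    for(unsigned i = 0; i < n / 4; ++i)
        reinterpret_cast<float4*>(dst)[i] = reinterpret_cast<const float4*>(src)[i];
}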
src/include/blockwise_4d_tensor_op.hip.hpp
@@ -207,7 +207,7 @@ template <unsigned BlockSize,
          unsigned DataPerRead>
struct Blockwise4dTensorCopy1
{
-    using vector_t = typename vector_type<Float, DataPerRead>::VectorType;
+    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    __device__ constexpr Blockwise4dTensorCopy1()
    {
@@ -444,7 +444,7 @@ template <unsigned BlockSize,
          unsigned DataPerRead>
struct Blockwise4dTensorCopy3
{
-    using vector_t = typename vector_type<Float, DataPerRead>::VectorType;
+    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    unsigned mSrcMyThreadOffset;
    unsigned mDstMyThreadOffset;
src/include/common.hip.hpp
#pragma once
#include "data_type.hip.hpp"
#include "constant_integral.hip.hpp"
#include "Sequence.hip.hpp"
#include "Array.hip.hpp"
@@ -20,97 +21,6 @@ struct is_same<T, T>
    static const bool value = true;
};

-template <class T, unsigned N>
-struct vector_type
-{
-};
-
-template <>
-struct vector_type<float, 1>
-{
-    using VectorType = float;
-};
-
-template <>
-struct vector_type<float, 2>
-{
-    using VectorType = float2;
-};
-
-template <>
-struct vector_type<float, 4>
-{
-    using VectorType = float4;
-};
-
-#if 0
-template <>
-struct vector_type<half_float::half, 1>
-{
-    using VectorType = half_float::half;
-};
-
-template <>
-struct vector_type<half_float::half, 2>
-{
-    using VectorType = float;
-};
-
-template <>
-struct vector_type<half_float::half, 4>
-{
-    using VectorType = float2;
-};
-
-template <>
-struct vector_type<half_float::half, 8>
-{
-    using VectorType = float4;
-};
-#endif
-
-#if 1
-template <>
-struct vector_type<half, 1>
-{
-    using VectorType = half;
-
-    __host__ __device__ static VectorType pack(half s)
-    {
-        return s;
-    }
-};
-
-template <>
-struct vector_type<half, 2>
-{
-    using VectorType = half2;
-
-    union Data
-    {
-        VectorType vector;
-        half scalar[2];
-    };
-
-    __host__ __device__ static VectorType pack(half s0, half s1)
-    {
-        Data data;
-        data.scalar[0] = s0;
-        data.scalar[1] = s1;
-        return data.vector;
-    }
-};
-
-template <>
-struct vector_type<half, 4>
-{
-    using VectorType = float2;
-};
-
-template <>
-struct vector_type<half, 8>
-{
-    using VectorType = float4;
-};
-#endif
-
template <typename T>
__host__ __device__ constexpr T max(T a, T b)
{
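These vector_type specializations are deleted here and, given the new data_type.hip.hpp include at the top of the file, presumably relocated rather than dropped. The vector_type<half, 2> specialization is the one the fp16 path relies on: pack type-puns two half scalars into one 32-bit half2 through the union. A small usage sketch, assuming the definitions above (or their relocated equivalents) are in scope:

// Illustration only: pack two fp16 values into one half2 element,
// which is the pre-vectorized unit the new kernel path reads.
__device__ void pack_example(half a, half b, half2& out)
{
    out = vector_type<half, 2>::pack(a, b);
}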
src/include/config.h.in
@@ -4,10 +4,8 @@
#if DEVICE_BACKEND_HIP
#include "hip/hip_runtime.h"
#include "half.hpp"
#elif DEVICE_BACKEND_CUDA
#include "cuda_runtime.h"
#include "nvToolsExt.h"
#include "helper_cuda.h"
#include "cuda_fp16.h"
#endif
src/include/functional.hip.hpp
@@ -47,3 +47,11 @@ struct static_const_reduce_n<1>
        return f(Number<0>{});
    }
};
#if 0
template<class F>
__host__ __device__ constexpr auto unpacker(F f)
{
return [=](auto xs_array){ f(xs...); };
}
#endif
\ No newline at end of file
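The disabled unpacker above captures xs_array but then expands a nonexistent pack xs..., so it would not compile if enabled. A standard-C++17 sketch of the apparent intent, using std::apply on a tuple-like argument rather than the repo's Array type (hypothetical, not the author's implementation):

#include <tuple>
#include <utility>

// Hypothetical sketch: wrap f so it can be invoked with a single tuple holding its arguments.
template <class F>
constexpr auto unpacker(F f)
{
    return [=](auto xs_tuple) { return std::apply(f, xs_tuple); };
}

// usage: unpacker([](int a, int b) { return a + b; })(std::make_tuple(1, 2)) == 3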
src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp
@@ -27,12 +27,14 @@ template <class Float,
          unsigned BlockSize,
          unsigned GridSize>
__global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
-    const typename vector_type<Float, ScalarPerVector>::VectorType* const __restrict__ p_in_vec_global,
-    const typename vector_type<Float, ScalarPerVector>::VectorType* const __restrict__ p_wei_vec_global,
+    const typename vector_type<Float, ScalarPerVector>::MemoryType* const __restrict__ p_in_vec_global,
+    const typename vector_type<Float, ScalarPerVector>::MemoryType* const __restrict__ p_wei_vec_global,
    Float* const __restrict__ p_out_global)
{
-    using scalar_t = Float;
-    using vector_t = typename vector_type<scalar_t, ScalarPerVector>::VectorType;
+    using scalar_t     = Float;
+    using vector_mem_t = typename vector_type<scalar_t, ScalarPerVector>::MemoryType;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
@@ -69,6 +71,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
    // shared mem
    constexpr unsigned in_block_size =
        in_nchw_vec_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
    constexpr unsigned wei_block_size =
        wei_kcyx_vec_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
@@ -76,8 +79,10 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
        ? InBlockCopyDataPerRead
        : WeiBlockCopyDataPerRead;

-    __shared__ vector_t p_in_vec_block[max_align * ((in_block_size + max_align - 1) / max_align)];
-    __shared__ vector_t p_wei_vec_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
+    __shared__ vector_mem_t p_in_vec_block[max_align * ((in_block_size + max_align - 1) / max_align)];
+    __shared__ vector_mem_t p_wei_vec_block[max_align * ((wei_block_size + max_align - 1) / max_align)];

    // threadwise tensors
    constexpr unsigned HiPerThread = HoPerThread + Y - 1;
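The LDS array sizes above are rounded up to a multiple of max_align so that both buffers stay aligned for the wider of the two vectorized reads; max_align * ((size + max_align - 1) / max_align) is the usual integer round-up. A small worked check with illustrative values only:

// round_up(size, align) is the smallest multiple of align that is >= size
constexpr unsigned round_up(unsigned size, unsigned align)
{
    return align * ((size + align - 1) / align);
}

static_assert(round_up(100, 4) == 100, "already a multiple of 4");
static_assert(round_up(101, 4) == 104, "rounded up to the next multiple of 4");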
@@ -150,7 +155,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
    constexpr auto blockwise_in_copy =
        Blockwise4dTensorCopy1<BlockSize,
-                               vector_t,
+                               vector_mem_t,
                               decltype(in_nchw_vec_global_desc),
                               decltype(in_nchw_vec_block_desc),
                               decltype(in_nchw_vec_block_desc.GetLengths()),
@@ -159,7 +164,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
#if 0
    constexpr auto blockwise_wei_copy =
        Blockwise4dTensorCopy1<BlockSize,
-                               vector_t,
+                               vector_mem_t,
                               decltype(wei_kcyx_vec_global_desc),
                               decltype(wei_kcyx_vec_block_desc),
                               decltype(wei_kcyx_vec_block_desc.GetLengths()),
@@ -167,7 +172,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
#elif 1
    const auto blockwise_wei_copy =
        Blockwise2dTensorCopy3<BlockSize,
-                               vector_t,
+                               vector_mem_t,
                               decltype(wei_ke_vec_global_desc),
                               decltype(wei_ke_vec_block_desc),
                               decltype(wei_ke_vec_block_desc.GetLengths()),
@@ -181,15 +186,16 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
         c_block_data_begin += CPerBlock, __syncthreads())
    {
        // copy input tensor to LDS
        blockwise_in_copy.Run(p_in_vec_global +
                                  in_nchw_vec_global_desc.Get1dIndex(n_block_data_begin,
                                                                     c_block_data_begin,
                                                                     hi_block_data_begin,
                                                                     wi_block_data_begin),
                              p_in_vec_block);

        // copy weight tensor to LDS
        blockwise_wei_copy.Run(p_wei_vec_global +
                                   wei_kcyx_vec_global_desc.Get1dIndex(
                                       k_block_data_begin, c_block_data_begin, 0, 0),
                               p_wei_vec_block);

        __syncthreads();
@@ -201,9 +207,9 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
                threadwise_direct_convolution_2(
                    in_nchw_vec_thread_block_desc,
                    p_in_vec_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
                                                                       c_thread_data,
                                                                       hi_thread_data_begin,
                                                                       wi_thread_data_begin),
                    wei_kcyx_vec_thread_block_desc,
                    p_wei_vec_block + wei_kcyx_vec_block_desc.Get1dIndex(
                                          k_thread_data_begin, c_thread_data, 0, 0),
@@ -213,9 +219,9 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
                threadwise_direct_convolution_3(
                    in_nchw_vec_thread_block_desc,
                    p_in_vec_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
                                                                       c_thread_data,
                                                                       hi_thread_data_begin,
                                                                       wi_thread_data_begin),
                    wei_kcyx_vec_thread_block_desc,
                    p_wei_vec_block + wei_kcyx_vec_block_desc.Get1dIndex(
                                          k_thread_data_begin, c_thread_data, 0, 0),
src/include/threadwise_direct_convolution.hip.hpp
@@ -2,13 +2,13 @@
#include "ConstantTensorDescriptor.hip.hpp"

// optimized for scenario if p_in, p_wei, p_out are in register
-template <class Float, class InDesc, class WeiDesc, class OutDesc>
+template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>
__device__ void threadwise_direct_convolution_1(InDesc,
-                                                Float* const __restrict__ p_in,
+                                                TInWei* const __restrict__ p_in,
                                                 WeiDesc,
-                                                Float* const __restrict__ p_wei,
+                                                TInWei* const __restrict__ p_wei,
                                                 OutDesc,
-                                                Float* __restrict__ p_out)
+                                                TOut* __restrict__ p_out)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
@@ -51,25 +51,10 @@ __device__ void threadwise_direct_convolution_1(InDesc,
                        const unsigned out_index = out_desc.Get1dIndex(n, k, ho, wo);

-                        p_out[out_index] += p_wei[wei_index] * p_in[in_index];
-
-#if 0
-                        // if(threadIdx.x == 0)
-                        {
-                            printf("threadwise_direct_convolution: \t"
-                                   "threadIdx.x %u\t"
-                                   "out_index %u, p_out[out_index] %f, \t"
-                                   "wei_index %u, p_wei[wei_index] %f, \t"
-                                   "in_index %u, p_in[in_index] %f\n",
-                                   threadIdx.x,
-                                   out_index,
-                                   p_out[out_index],
-                                   wei_index,
-                                   p_wei[wei_index],
-                                   in_index,
-                                   p_in[in_index]);
-                        }
-#endif
+                        fused_multiply_add(
+                            p_out[out_index], p_wei[wei_index], p_in[in_index], p_out[out_index]);
                    }
                }
            }
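The raw multiply-accumulate is replaced by a call to fused_multiply_add, which lets the same inner loop serve both scalar and packed element types now that input/weight and output can differ in type. Its definition is not part of this diff; inferring the signature from the call site (destination first, then the two factors and the addend), the scalar case presumably reduces to something like:

#include "hip/hip_runtime.h"

// Hypothetical scalar overload matching fused_multiply_add(d, a, b, c), i.e. d = a * b + c.
// A packed overload (e.g. for half2 input/weight) would accumulate both lanes.
template <class T>
__device__ void fused_multiply_add(T& d, T a, T b, T c)
{
    d = a * b + c;
}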
@@ -81,13 +66,13 @@ __device__ void threadwise_direct_convolution_1(InDesc,
// Optimized for scenario if p_in and p_wei are in LDS, p_out are in register
// Copy in and wei into register before doing convolution
-template <class Float, class InDesc, class WeiDesc, class OutDesc>
+template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>
__device__ void threadwise_direct_convolution_2(InDesc,
-                                                Float* const __restrict__ p_in,
+                                                TInWei* const __restrict__ p_in,
                                                 WeiDesc,
-                                                Float* const __restrict__ p_wei,
+                                                TInWei* const __restrict__ p_wei,
                                                 OutDesc,
-                                                Float* __restrict__ p_out)
+                                                TOut* __restrict__ p_out)
{
    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};
@@ -97,8 +82,8 @@ __device__ void threadwise_direct_convolution_2(InDesc,
    constexpr auto wei_reg_desc = make_ConstantTensorDescriptor(wei_desc.GetLengths());

    // register
-    Float p_in_reg[in_reg_desc.GetElementSpace()];
-    Float p_wei_reg[wei_reg_desc.GetElementSpace()];
+    TInWei p_in_reg[in_reg_desc.GetElementSpace()];
+    TInWei p_wei_reg[wei_reg_desc.GetElementSpace()];

    // copy input tensor into register
    threadwise_4d_tensor_copy(in_desc, p_in, in_reg_desc, p_in_reg, in_reg_desc.GetLengths());
@@ -114,13 +99,13 @@ __device__ void threadwise_direct_convolution_2(InDesc,
// optimized for scenario where p_in and p_wei are in LDS, p_out is in register
// break down a non-1x1 convolution into a sequence of 1x1 convolutions,
// load 1x1 weight into register, and do 1x1 convolution in register.
-template <class Float, class InDesc, class WeiDesc, class OutDesc>
+template <class Data, class InDesc, class WeiDesc, class OutDesc>
__device__ void threadwise_direct_convolution_3(InDesc,
-                                                Float* const __restrict__ p_in,
+                                                Data* const __restrict__ p_in,
                                                 WeiDesc,
-                                                Float* const __restrict__ p_wei,
+                                                Data* const __restrict__ p_wei,
                                                 OutDesc,
-                                                Float* __restrict__ p_out)
+                                                Data* __restrict__ p_out)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
@@ -139,8 +124,8 @@ __device__ void threadwise_direct_convolution_3(InDesc,
    constexpr auto wei_reg_desc = make_ConstantTensorDescriptor(
        Sequence<wei_desc.GetLength(I0), wei_desc.GetLength(I1), 1, 1>{});

-    Float p_in_reg[in_reg_desc.GetElementSpace()];
-    Float p_wei_reg[wei_reg_desc.GetElementSpace()];
+    Data p_in_reg[in_reg_desc.GetElementSpace()];
+    Data p_wei_reg[wei_reg_desc.GetElementSpace()];

    constexpr unsigned in_w_new_read = 1;
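Per the comments above, threadwise_direct_convolution_3 expresses a YxX convolution as Y*X shifted 1x1 convolutions, staging each 1x1 weight slice in registers. A plain-C++ host sketch of that decomposition (hypothetical helper, not the device code; 'out' must be zero-initialized, no padding or stride):

#include <cstddef>

// A YxX convolution accumulated as Y*X shifted 1x1 convolutions over NCHW / KCYX / NKHW layouts.
void conv_as_sum_of_1x1(const float* in, const float* wei, float* out,
                        int N, int C, int K, int Hi, int Wi, int Y, int X)
{
    const int Ho = Hi - Y + 1;
    const int Wo = Wi - X + 1;

    for(int y = 0; y < Y; ++y)
        for(int x = 0; x < X; ++x) // each filter tap (y, x) is one 1x1 convolution
            for(int n = 0; n < N; ++n)
                for(int k = 0; k < K; ++k)
                    for(int c = 0; c < C; ++c)
                        for(int ho = 0; ho < Ho; ++ho)
                            for(int wo = 0; wo < Wo; ++wo)
                                out[((std::size_t(n) * K + k) * Ho + ho) * Wo + wo] +=
                                    wei[((std::size_t(k) * C + c) * Y + y) * X + x] *
                                    in[((std::size_t(n) * C + c) * Hi + (ho + y)) * Wi + (wo + x)];
}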