Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b530abc4
Commit
b530abc4
authored
Jun 05, 2019
by
Jing Zhang
Browse files
add support of stride
parent
eafdabba
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
68 additions
and
27 deletions
+68
-27
driver/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp
...er/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp
+3
-1
driver/driver.hip.cpp
driver/driver.hip.cpp
+31
-15
src/include/conv_common.hip.hpp
src/include/conv_common.hip.hpp
+7
-4
src/include/gridwise_convolution_implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp
...implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp
+27
-7
No files found.
driver/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp
View file @
b530abc4
...
...
@@ -5,12 +5,13 @@
#include "gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp"
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
,
class
Strides
>
void
device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw
(
InDesc
,
const
Tensor
<
T
>&
in_nchw
,
WeiDesc
,
const
Tensor
<
T
>&
wei_kcyx
,
OutDesc
,
Strides
,
Tensor
<
T
>&
out_nkhw
,
index_t
nrepeat
)
{
...
...
@@ -100,6 +101,7 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
#endif
<
GridSize
,
BlockSize
,
Strides
,
T
,
decltype
(
in_nchw_desc
),
decltype
(
wei_kcyx_desc
),
...
...
driver/driver.hip.cpp
View file @
b530abc4
...
...
@@ -103,12 +103,13 @@ auto make_TensorDescriptor(TConstTensorDesc)
return
TensorDescriptor
(
lengths
,
strides
);
}
template
<
class
TIn
,
class
TWei
,
class
TOut
,
class
LowerPads
,
class
UpperPads
>
template
<
class
TIn
,
class
TWei
,
class
TOut
,
class
LowerPads
,
class
UpperPads
,
class
Strides
>
void
host_direct_convolution
(
const
Tensor
<
TIn
>&
in_nchw
,
const
Tensor
<
TWei
>&
wei_kcyx
,
Tensor
<
TOut
>&
out_nkhw
,
LowerPads
,
UpperPads
)
UpperPads
,
Strides
)
{
index_t
h_pad_low
=
LowerPads
{}.
Get
(
Number
<
0
>
{});
index_t
w_pad_low
=
LowerPads
{}.
Get
(
Number
<
1
>
{});
...
...
@@ -116,16 +117,19 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
index_t
h_pad_up
=
UpperPads
{}.
Get
(
Number
<
0
>
{});
index_t
w_pad_up
=
UpperPads
{}.
Get
(
Number
<
1
>
{});
index_t
stride_h
=
Strides
{}.
Get
(
Number
<
0
>
{});
index_t
stride_w
=
Strides
{}.
Get
(
Number
<
1
>
{});
auto
f
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
double
v
=
0
;
for
(
int
c
=
0
;
c
<
wei_kcyx
.
mDesc
.
GetLengths
()[
1
];
++
c
)
{
for
(
int
y
=
0
;
y
<
wei_kcyx
.
mDesc
.
GetLengths
()[
2
];
++
y
)
{
int
hi
=
ho
+
y
-
h_pad_low
;
int
hi
=
ho
*
stride_h
+
y
-
h_pad_low
;
for
(
int
x
=
0
;
x
<
wei_kcyx
.
mDesc
.
GetLengths
()[
3
];
++
x
)
{
int
wi
=
wo
+
x
-
w_pad_low
;
int
wi
=
wo
*
stride_w
+
x
-
w_pad_low
;
if
(
hi
>=
0
&&
hi
<
in_nchw
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
wi
<
in_nchw
.
mDesc
.
GetLengths
()[
3
])
{
...
...
@@ -408,14 +412,16 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
int
main
(
int
argc
,
char
*
argv
[])
{
constexpr
index_t
U
=
2
;
constexpr
index_t
V
=
2
;
#if 0
constexpr index_t N = 8;
constexpr index_t C = 16;
constexpr index_t HI =
3
;
constexpr index_t WI = 1
8
;
constexpr index_t HI =
16
;
constexpr index_t WI = 1
6
;
constexpr index_t K = 128;
constexpr index_t Y =
3
;
constexpr index_t X =
3
;
constexpr index_t Y =
1
;
constexpr index_t X =
1
;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
...
...
@@ -443,7 +449,7 @@ int main(int argc, char* argv[])
constexpr
index_t
HPad
=
0
;
constexpr
index_t
WPad
=
0
;
#elif
1
#elif
0
// 3x3 filter, 28x28 image
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
256
;
...
...
@@ -455,7 +461,7 @@ int main(int argc, char* argv[])
constexpr
index_t
HPad
=
0
;
constexpr
index_t
WPad
=
0
;
#elif
0
#elif
1
// 1x1 filter, 28x28 image
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
512
;
...
...
@@ -580,10 +586,12 @@ int main(int argc, char* argv[])
auto
lower_pads
=
Sequence
<
HPad
,
WPad
>
{};
auto
upper_pads
=
Sequence
<
HPad
,
WPad
>
{};
auto
strides
=
Sequence
<
U
,
V
>
{};
auto
in_nchw_desc
=
make_ConstantTensorDescriptor_packed
(
Sequence
<
N
,
C
,
HI
,
WI
>
{});
auto
wei_kcyx_desc
=
make_ConstantTensorDescriptor_packed
(
Sequence
<
K
,
C
,
Y
,
X
>
{});
auto
out_nkhw_desc
=
get_convolution_with_padding_output_default_4d_tensor_descriptor
(
in_nchw_desc
,
wei_kcyx_desc
,
lower_pads
,
upper_pad
s
);
auto
out_nkhw_desc
=
get_convolution_output_default_4d_tensor_descriptor
(
in_nchw_desc
,
wei_kcyx_desc
,
stride
s
);
ostream_ConstantTensorDescriptor
(
in_nchw_desc
,
std
::
cout
<<
"in_nchw_desc: "
);
ostream_ConstantTensorDescriptor
(
wei_kcyx_desc
,
std
::
cout
<<
"wei_kcyx_desc: "
);
...
...
@@ -651,7 +659,14 @@ int main(int argc, char* argv[])
#elif 1
device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw
#endif
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
strides
,
out_nkhw_device
,
nrepeat
);
#elif 1
device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded
(
in_nchw_desc
,
...
...
@@ -667,7 +682,7 @@ int main(int argc, char* argv[])
if
(
do_verification
)
{
#if
1
#if
0
if(Y == 3 && X == 3)
{
host_winograd_3x3_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads);
...
...
@@ -675,7 +690,8 @@ int main(int argc, char* argv[])
else
#endif
{
host_direct_convolution
(
in_nchw
,
wei_kcyx
,
out_nkhw_host
,
lower_pads
,
upper_pads
);
host_direct_convolution
(
in_nchw
,
wei_kcyx
,
out_nkhw_host
,
lower_pads
,
upper_pads
,
strides
);
}
check_error
(
out_nkhw_host
,
out_nkhw_device
);
...
...
src/include/conv_common.hip.hpp
View file @
b530abc4
...
...
@@ -2,8 +2,8 @@
#include "ConstantTensorDescriptor.hip.hpp"
// this is ugly, only for 4d
template
<
class
InDesc
,
class
WeiDesc
>
constexpr
auto
get_convolution_output_default_4d_tensor_descriptor
(
InDesc
,
WeiDesc
)
template
<
class
InDesc
,
class
WeiDesc
,
class
Strides
>
constexpr
auto
get_convolution_output_default_4d_tensor_descriptor
(
InDesc
,
WeiDesc
,
Strides
)
{
constexpr
auto
in_desc
=
InDesc
{};
constexpr
auto
wei_desc
=
WeiDesc
{};
...
...
@@ -26,8 +26,11 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor(InDesc, WeiDe
constexpr
auto
Y
=
wei_desc
.
GetLength
(
I2
);
constexpr
auto
X
=
wei_desc
.
GetLength
(
I3
);
constexpr
auto
HO
=
HI
+
1
-
Y
;
constexpr
auto
WO
=
WI
+
1
-
X
;
constexpr
index_t
U
=
Strides
{}.
Get
(
I0
);
constexpr
index_t
V
=
Strides
{}.
Get
(
I1
);
constexpr
auto
HO
=
(
HI
-
Y
)
/
U
+
1
;
constexpr
auto
WO
=
(
WI
-
X
)
/
V
+
1
;
return
make_ConstantTensorDescriptor_packed
(
Sequence
<
N
,
K
,
HO
,
WO
>
{});
}
...
...
src/include/gridwise_convolution_implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp
View file @
b530abc4
...
...
@@ -10,6 +10,7 @@
// define B = merge(N, Ho, Wo)
template
<
index_t
GridSize
,
index_t
BlockSize
,
class
Strides
,
class
Float
,
class
InGlobalDesc
,
class
WeiGlobalDesc
,
...
...
@@ -67,7 +68,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
constexpr
auto
in_n_c_h_w_global_desc
=
InGlobalDesc
{};
constexpr
auto
in_n_c_h_w_global_desc
=
InGlobalDesc
{};
// to-do: backward data: 1) ckyx: yx unfold, 2) merge cyx = e, 3 out = ek
constexpr
auto
wei_k_c_y_x_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_n_k_h_w_global_desc
=
OutGlobalDesc
{};
...
...
@@ -109,19 +111,37 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
// input tensor
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
constexpr
auto
in_n0_n1_n2_h_w_global_desc
=
in_n_c_h_w_global_desc
.
Slice
(
I2
,
Number
<
Ho
>
{})
.
Slice
(
I3
,
Number
<
Wo
>
{})
.
Fold
(
I0
,
Number
<
N1
>
{},
Number
<
N2
>
{})
.
Extract
(
Sequence
<
0
,
1
,
2
,
4
,
5
>
{});
// constexpr auto in_n0_n1_n2_h_w_global_desc = in_n_c_h_w_global_desc.Slice(I2,
// Number<Ho>{})
//.Slice(I3, Number<Wo>{})
//.Fold(I0, Number<N1>{}, Number<N2>{})
//.Extract(Sequence<0, 1, 2, 4, 5>{});
constexpr
auto
in_n0_n1_n2_h_w_global_desc
=
in_n_c_h_w_global_desc
.
Fold
(
I0
,
Number
<
N1
>
{},
Number
<
N2
>
{})
.
Extract
(
Sequence
<
0
,
1
,
2
,
4
,
5
>
{});
constexpr
auto
new_lengths
=
Sequence
<
N0
,
N1
,
N2
,
Ho
,
Wo
>
{};
constexpr
auto
new_strides
=
Sequence
<
in_n0_n1_n2_h_w_global_desc
.
GetStride
(
I0
),
in_n0_n1_n2_h_w_global_desc
.
GetStride
(
I1
),
in_n0_n1_n2_h_w_global_desc
.
GetStride
(
I2
),
in_n0_n1_n2_h_w_global_desc
.
GetStride
(
I3
)
*
Strides
{}.
Get
(
I0
),
in_n0_n1_n2_h_w_global_desc
.
GetStride
(
I4
)
*
Strides
{}.
Get
(
I1
)
>
{};
constexpr
auto
in_n0_n1_n2_h_w_new_global_desc
=
make_ConstantTensorDescriptor
(
new_lengths
,
new_strides
);
// batch descritpor for device memory
// to-do: add dilation: keep lengths, modify strides
constexpr
auto
in_c_y_x_global_desc
=
in_n_c_h_w_global_desc
.
Slice
(
I2
,
Number
<
Y
>
{})
.
Slice
(
I3
,
Number
<
X
>
{})
.
Extract
(
Sequence
<
1
,
2
,
3
>
{});
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
constexpr
auto
in_e_n1_b_n2_global_merged_desc
=
make_ConstantMergedTensorDescriptor
(
in_c_y_x_global_desc
.
Embed
(
in_n0_n1_n2_h_w_global_desc
),
in_c_y_x_global_desc
.
Embed
(
in_n0_n1_n2_h_w_
new_
global_desc
),
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
6
,
7
>
{},
...
...
@@ -246,7 +266,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
// choose GEMM implementation here
const
auto
run_blockwise_gemm
=
[
&
](
auto
...
Xs
)
{
#if
1
#if
0
return blockwise_gemm.Run(Xs...);
#else
return
blockwise_gemm
.
Run_asm
(
Xs
...);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment