Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
a47845f2
Commit
a47845f2
authored
Jun 07, 2019
by
Jing Zhang
Browse files
fix
parent
de6f254d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
28 additions
and
15 deletions
+28
-15
driver/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp
...er/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp
+1
-1
driver/driver.hip.cpp
driver/driver.hip.cpp
+27
-14
No files found.
driver/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp
View file @
a47845f2
...
...
@@ -88,7 +88,7 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
#endif
constexpr
index_t
GridSize
=
((
B
+
BPerBlock
-
1
)
/
BPerBlock
)
*
((
K
+
KPerBlock
-
1
)
/
KPerBlock
);
((
B
+
BPerBlock
-
1
)
/
BPerBlock
)
*
((
K
+
KPerBlock
-
1
)
/
KPerBlock
)
/
(
Strides
{}.
Get
(
I1
)
*
Strides
{}.
Get
(
I0
))
;
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
...
...
driver/driver.hip.cpp
View file @
a47845f2
...
...
@@ -16,6 +16,17 @@
#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp"
struct
GeneratorTensor_0
{
template
<
class
...
Is
>
double
operator
()(
Is
...
is
)
{
return
0
;
}
};
struct
GeneratorTensor_1
{
template
<
class
...
Is
>
...
...
@@ -196,12 +207,14 @@ void host_direct_convolution_back(Tensor<TOut>& in_nchw,
{
for
(
int
y
=
0
;
y
<
wei_kcyx
.
mDesc
.
GetLengths
()[
2
];
++
y
)
{
int
ho
=
(
hi
-
y
*
dilation_h
+
h_pad_low
)
/
stride_h
;
int
ho_
=
(
hi
-
y
*
dilation_h
+
h_pad_low
);
int
ho
=
ho_
/
stride_h
;
for
(
int
x
=
0
;
x
<
wei_kcyx
.
mDesc
.
GetLengths
()[
3
];
++
x
)
{
int
wo
=
(
wi
-
x
*
dilation_w
+
w_pad_low
)
/
stride_w
;
if
(
ho
>=
0
&&
hi
<
out_nkhw
.
mDesc
.
GetLengths
()[
2
]
&&
wo
>=
0
&&
wo
<
out_nkhw
.
mDesc
.
GetLengths
()[
3
]
&&
ho
%
stride_h
==
0
&&
wo
%
stride_w
==
0
)
int
wo_
=
(
wi
-
x
*
dilation_w
+
w_pad_low
);
int
wo
=
wo_
/
stride_w
;
if
(
ho
>=
0
&&
ho
<
out_nkhw
.
mDesc
.
GetLengths
()[
2
]
&&
wo
>=
0
&&
wo
<
out_nkhw
.
mDesc
.
GetLengths
()[
3
]
&&
ho_
%
stride_h
==
0
&&
wo_
%
stride_w
==
0
)
{
v
+=
double
(
out_nkhw
(
n
,
k
,
ho
,
wo
))
*
double
(
wei_kcyx
(
k
,
c
,
y
,
x
));
}
...
...
@@ -489,12 +502,12 @@ int main(int argc, char* argv[])
constexpr
index_t
WDilation
=
1
;
constexpr
index_t
Direction
=
2
;
//1: Forward; 2:Backward
#if
1
#if
0
constexpr index_t N = 8;
constexpr index_t C = 128;
constexpr
index_t
HI
=
16
;
constexpr
index_t
WI
=
16
;
constexpr
index_t
K
=
1
28
;
constexpr index_t HI =
2
;
constexpr index_t WI =
32
;
constexpr index_t K = 1
6
;
constexpr index_t Y = 1;
constexpr index_t X = 1;
...
...
@@ -706,9 +719,9 @@ int main(int argc, char* argv[])
in_nchw
.
GenerateTensorValue
(
GeneratorTensor_3
{},
num_thread
);
wei_kcyx
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
#elif 1
in_nchw
.
GenerateTensorValue
(
GeneratorTensor_
2
{
-
5
,
5
},
num_thread
);
out_nkhw
.
GenerateTensorValue
(
GeneratorTensor_
2
{
-
5
,
5
},
num_thread
);
wei_kcyx
.
GenerateTensorValue
(
GeneratorTensor_
2
{
-
5
,
5
},
num_thread
);
in_nchw
.
GenerateTensorValue
(
GeneratorTensor_
0
{
},
num_thread
);
out_nkhw
.
GenerateTensorValue
(
GeneratorTensor_
1
{
},
num_thread
);
wei_kcyx
.
GenerateTensorValue
(
GeneratorTensor_
1
{
},
num_thread
);
#elif 0
in_nchw
.
GenerateTensorValue
(
GeneratorTensor_2
{
1
,
5
},
num_thread
);
...
...
@@ -785,10 +798,10 @@ int main(int argc, char* argv[])
}
#if 0
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
//
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
LogRange(std::cout << "
out
_n
k
hw_host : ",
out
_n
k
hw
_host
.mData, ",") << std::endl;
LogRange(std::cout << "
out
_n
k
hw_device: ",
out
_n
k
hw_device.mData, ",") << std::endl;
LogRange(std::cout << "
in
_n
c
hw_host : ",
in
_n
c
hw.mData, ",") << std::endl;
LogRange(std::cout << "
in
_n
c
hw_device: ",
in
_n
c
hw_device.mData, ",") << std::endl;
#endif
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment