Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f2ed2b3b
Commit
f2ed2b3b
authored
Jan 13, 2023
by
Paul
Browse files
Format
parent
692ce4b0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
13 deletions
+12
-13
src/include/migraphx/convolution.hpp
src/include/migraphx/convolution.hpp
+12
-13
No files found.
src/include/migraphx/convolution.hpp
View file @
f2ed2b3b
...
@@ -34,11 +34,11 @@
...
@@ -34,11 +34,11 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Output
,
class
T
,
class
Padding
,
class
Stride
>
template
<
class
Output
,
class
T
,
class
Padding
,
class
Stride
>
void
convolution
(
Output
output
,
T
input
,
T
weights
,
Padding
padding
,
Stride
stride
,
int
group
)
void
convolution
(
Output
output
,
T
input
,
T
weights
,
Padding
padding
,
Stride
stride
,
int
group
)
{
{
auto
output_shape
=
output
.
get_shape
();
auto
output_shape
=
output
.
get_shape
();
auto
in_lens
=
input
.
get_shape
().
lens
();
auto
in_lens
=
input
.
get_shape
().
lens
();
auto
wei_lens
=
weights
.
get_shape
().
lens
();
auto
wei_lens
=
weights
.
get_shape
().
lens
();
auto
wei_n
=
wei_lens
[
0
];
auto
wei_n
=
wei_lens
[
0
];
...
@@ -68,22 +68,21 @@ void convolution(Output output, T input, T weights, Padding padding, Stride stri
...
@@ -68,22 +68,21 @@ void convolution(Output output, T input, T weights, Padding padding, Stride stri
std
::
vector
<
std
::
ptrdiff_t
>
idx
(
idx_o
.
begin
(),
idx_o
.
end
());
std
::
vector
<
std
::
ptrdiff_t
>
idx
(
idx_o
.
begin
(),
idx_o
.
end
());
idx
[
1
]
=
in_ch
;
idx
[
1
]
=
in_ch
;
std
::
transform
(
idx_win
.
begin
()
+
1
,
std
::
transform
(
idx_win
.
begin
()
+
1
,
idx_win
.
end
(),
idx_win
.
end
(),
win_start
.
begin
(),
win_start
.
begin
(),
idx
.
begin
()
+
2
,
idx
.
begin
()
+
2
,
[](
std
::
ptrdiff_t
ii
,
std
::
ptrdiff_t
jj
)
{
return
ii
+
jj
;
});
[](
std
::
ptrdiff_t
ii
,
std
::
ptrdiff_t
jj
)
{
return
ii
+
jj
;
});
std
::
vector
<
std
::
ptrdiff_t
>
idx_wei
(
idx_o
.
size
());
std
::
vector
<
std
::
ptrdiff_t
>
idx_wei
(
idx_o
.
size
());
idx_wei
[
0
]
=
w
;
idx_wei
[
0
]
=
w
;
std
::
copy
(
idx_win
.
begin
(),
idx_win
.
end
(),
idx_wei
.
begin
()
+
1
);
std
::
copy
(
idx_win
.
begin
(),
idx_win
.
end
(),
idx_wei
.
begin
()
+
1
);
if
(
std
::
all_of
(
idx
.
begin
()
+
2
,
idx
.
end
(),
[
&
](
auto
ii
)
{
return
ii
>=
0
;
})
and
if
(
std
::
all_of
(
idx
.
begin
()
+
2
,
idx
.
end
(),
[
&
](
auto
ii
)
{
return
ii
>=
0
;
})
and
std
::
equal
(
idx
.
begin
(),
std
::
equal
(
idx
.
begin
(),
idx
.
end
(),
idx
.
end
(),
in_lens
.
begin
(),
in_lens
.
begin
(),
in_lens
.
end
(),
in_lens
.
end
(),
std
::
less
<
std
::
ptrdiff_t
>
{}))
std
::
less
<
std
::
ptrdiff_t
>
{}))
{
{
acc
+=
acc
+=
input
(
idx
.
begin
(),
idx
.
end
())
*
weights
(
idx_wei
.
begin
(),
idx_wei
.
end
());
input
(
idx
.
begin
(),
idx
.
end
())
*
weights
(
idx_wei
.
begin
(),
idx_wei
.
end
());
}
}
});
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment