Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1dca597f
Commit
1dca597f
authored
Oct 01, 2018
by
Scott Thornton
Browse files
Formatting
parent
8e4b1022
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
6 deletions
+9
-6
src/include/migraph/operators.hpp
src/include/migraph/operators.hpp
+7
-5
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+2
-1
No files found.
src/include/migraph/operators.hpp
View file @
1dca597f
...
...
@@ -622,13 +622,14 @@ struct broadcast
std
::
string
name
()
const
{
return
"broadcast"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
auto
t
=
inputs
.
at
(
0
).
type
();
auto
input
=
inputs
.
at
(
0
);
auto
t
=
inputs
.
at
(
0
).
type
();
auto
input
=
inputs
.
at
(
0
);
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_shape
.
lens
().
size
(),
0
);
if
(
std
::
all_of
(
broadcast_shape
.
lens
().
cbegin
(),
broadcast_shape
.
lens
().
cend
(),
[
&
](
auto
x
)
{
return
x
==
1
;
}))
if
(
std
::
all_of
(
broadcast_shape
.
lens
().
cbegin
(),
broadcast_shape
.
lens
().
cend
(),
[
&
](
auto
x
)
{
return
x
==
1
;
}))
{
if
(
axis
!=
0
)
MIGRAPH_THROW
(
"when broadcasting tensor of size 1, axis should be 0"
);
...
...
@@ -637,7 +638,8 @@ struct broadcast
else
{
assert
(
broadcast_shape
.
lens
().
size
()
-
axis
>=
input
.
lens
().
size
());
if
(
!
std
::
equal
(
input
.
lens
().
begin
(),
input
.
lens
().
end
(),
broadcast_shape
.
lens
().
begin
()
+
axis
))
if
(
!
std
::
equal
(
input
.
lens
().
begin
(),
input
.
lens
().
end
(),
broadcast_shape
.
lens
().
begin
()
+
axis
))
MIGRAPH_THROW
(
"when broadcasting success sizes must match"
);
std
::
copy
(
input
.
strides
().
begin
(),
input
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
return
{
t
,
broadcast_shape
.
lens
(),
std
::
move
(
bcast_strides
)};
...
...
src/onnx/onnx.cpp
View file @
1dca597f
...
...
@@ -93,7 +93,8 @@ struct onnx_parser
uint64_t
axis
=
(
contains
(
attributes
,
"axis"
))
?
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
uint64_t
>
()
:
0
;
auto
l
=
prog
.
add_instruction
(
op
::
broadcast
{
axis
,
args
[
0
]
->
get_shape
()},
args
[
1
]);
auto
l
=
prog
.
add_instruction
(
op
::
broadcast
{
axis
,
args
[
0
]
->
get_shape
()},
args
[
1
]);
return
prog
.
add_instruction
(
x
,
args
[
0
],
l
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment