Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
fbef5744
Commit
fbef5744
authored
Nov 27, 2018
by
Khalique
Browse files
add workaround for scalar
parent
71276f4d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
25 deletions
+15
-25
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+15
-25
No files found.
src/onnx/onnx.cpp
View file @
fbef5744
...
...
@@ -137,36 +137,21 @@ struct onnx_parser
const
std
::
vector
<
std
::
size_t
>*
s0
=
&
args
[
0
]
->
get_shape
().
lens
();
const
std
::
vector
<
std
::
size_t
>*
s1
=
&
args
[
1
]
->
get_shape
().
lens
();
bool
swapped
=
false
;
// Make sure s0 is the smaller size
if
(
s0
->
size
()
>
s1
->
size
())
{
std
::
swap
(
s0
,
s1
);
swapped
=
true
;
}
std
::
vector
<
std
::
size_t
>
output_lens
(
s1
->
size
());
// if (s0->size() == 0)
// {
// shape s = swapped ? args[0]->get_shape() : args[1]->get_shape();
// auto l0 = prog.add_instruction(migraphx::op::scalar{s}, 1.0f);
// return prog.add_instruction(x, l0, args[1]);
// }
// else
// {
// Copy the larger vector to output_lens
auto
offset
=
s1
->
size
()
-
s0
->
size
();
std
::
transform
(
s0
->
begin
(),
s0
->
end
(),
s1
->
begin
()
+
offset
,
output_lens
.
begin
()
+
offset
,
[](
auto
a
,
auto
b
)
{
return
std
::
max
(
a
,
b
);
});
auto
l0
=
prog
.
add_instruction
(
op
::
multibroadcast
{
output_lens
},
args
[
0
]);
auto
l1
=
prog
.
add_instruction
(
op
::
multibroadcast
{
output_lens
},
args
[
1
]);
return
prog
.
add_instruction
(
x
,
l0
,
l1
);
// }
auto
offset
=
s1
->
size
()
-
s0
->
size
();
std
::
transform
(
s0
->
begin
(),
s0
->
end
(),
s1
->
begin
()
+
offset
,
output_lens
.
begin
()
+
offset
,
[](
auto
a
,
auto
b
)
{
return
std
::
max
(
a
,
b
);
});
auto
l0
=
prog
.
add_instruction
(
op
::
multibroadcast
{
output_lens
},
args
[
0
]);
auto
l1
=
prog
.
add_instruction
(
op
::
multibroadcast
{
output_lens
},
args
[
1
]);
return
prog
.
add_instruction
(
x
,
l0
,
l1
);
}
else
{
...
...
@@ -602,6 +587,11 @@ struct onnx_parser
static
literal
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
{
std
::
vector
<
std
::
size_t
>
dims
(
t
.
dims
().
begin
(),
t
.
dims
().
end
());
// in case of scalar constants in onnx file, use dims=1 to fill initializer data
if
(
dims
.
size
()
==
0
)
{
dims
=
{
1
};
}
if
(
t
.
has_raw_data
())
{
const
std
::
string
&
s
=
t
.
raw_data
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment