Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
3db703df
Unverified
Commit
3db703df
authored
Jul 03, 2019
by
mvermeulen
Committed by
GitHub
Jul 03, 2019
Browse files
Merge pull request #294 from ROCmSoftwarePlatform/multibroadcast_bug
Fix a bug in the multibroadcast
parents
51f264a6
93d44e6e
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
52 additions
and
19 deletions
+52
-19
src/include/migraphx/op/multibroadcast.hpp
src/include/migraphx/op/multibroadcast.hpp
+17
-3
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+9
-1
test/onnx/implicit_bcast_test.onnx
test/onnx/implicit_bcast_test.onnx
+10
-9
test/onnx/implicit_sub_bcast_test.onnx
test/onnx/implicit_sub_bcast_test.onnx
+4
-4
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+2
-2
test/op_shape_test.cpp
test/op_shape_test.cpp
+10
-0
No files found.
src/include/migraphx/op/multibroadcast.hpp
View file @
3db703df
...
...
@@ -35,14 +35,28 @@ struct multibroadcast
auto
input
=
inputs
.
at
(
0
);
if
(
input
.
lens
().
empty
())
MIGRAPHX_THROW
(
"inputs dimensions should be > 0"
);
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: inputs dimensions should be > 0"
);
}
if
(
input
.
lens
().
size
()
>
output_lens
.
size
())
MIGRAPHX_THROW
(
"inputs dimensions should <= output size"
);
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: inputs dimensions should <= output size"
);
}
std
::
vector
<
size_t
>
bcast_strides
(
output_lens
.
size
(),
0
);
auto
offset
=
output_lens
.
size
()
-
input
.
lens
().
size
();
for
(
std
::
ptrdiff_t
i
=
input
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
{
if
(
output_lens
[
i
+
offset
]
!=
input
.
lens
()[
i
]
and
input
.
lens
()[
i
]
!=
1
)
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: input shape {"
+
to_string_range
(
input
.
lens
())
+
"} cannot be broadcasted to {"
+
to_string_range
(
output_lens
)
+
"}!"
);
}
}
std
::
vector
<
size_t
>
bcast_strides
(
output_lens
.
size
(),
0
);
for
(
std
::
ptrdiff_t
i
=
input
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
{
if
(
output_lens
[
i
+
offset
]
==
input
.
lens
()[
i
])
{
...
...
src/onnx/onnx.cpp
View file @
3db703df
...
...
@@ -182,7 +182,15 @@ struct onnx_parser
s0
.
end
(),
s1
.
begin
()
+
offset
,
out_lens
.
begin
()
+
offset
,
[](
auto
a
,
auto
b
)
{
return
std
::
max
(
a
,
b
);
});
[
&
](
auto
a
,
auto
b
)
{
if
(
a
!=
b
and
a
!=
1
and
b
!=
1
)
{
MIGRAPHX_THROW
(
"COMPUTE_BROADCASTLEN: shape {"
+
to_string_range
(
s0
)
+
"} and {"
+
to_string_range
(
s1
)
+
"} mismatch!"
);
}
return
std
::
max
(
a
,
b
);
});
return
out_lens
;
}
...
...
test/onnx/implicit_bcast_test.onnx
View file @
3db703df
implicit_bcast-example:q
add2:u
0
1
2
"Add
test-multi_bcast
Z
1
out
"Add
subtraction2
Z
0
Z
1
Z
1
b
2
b
out
B
\ No newline at end of file
B
test/onnx/implicit_sub_bcast_test.onnx
View file @
3db703df
subtraction
2:q
add
2:q
0
1out"Subsubtraction2Z
...
...
@@ -10,11 +10,11 @@
Z
1
b
b
out
B
\ No newline at end of file
B
test/onnx/onnx_test.cpp
View file @
3db703df
...
...
@@ -350,7 +350,7 @@ TEST_CASE(implicit_add_bcast_test)
{
migraphx
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}});
auto
l1
=
p
.
add_parameter
(
"1"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
}});
auto
l1
=
p
.
add_parameter
(
"1"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
1
}});
auto
l2
=
p
.
add_instruction
(
migraphx
::
op
::
multibroadcast
{{
2
,
3
,
4
,
5
}},
l0
);
auto
l3
=
p
.
add_instruction
(
migraphx
::
op
::
multibroadcast
{{
2
,
3
,
4
,
5
}},
l1
);
p
.
add_instruction
(
migraphx
::
op
::
add
{},
l2
,
l3
);
...
...
@@ -377,7 +377,7 @@ TEST_CASE(implicit_sub_bcast_test)
{
migraphx
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}});
auto
l1
=
p
.
add_parameter
(
"1"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
}});
auto
l1
=
p
.
add_parameter
(
"1"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}});
auto
l2
=
p
.
add_instruction
(
migraphx
::
op
::
multibroadcast
{{
2
,
3
,
4
,
5
}},
l0
);
auto
l3
=
p
.
add_instruction
(
migraphx
::
op
::
multibroadcast
{{
2
,
3
,
4
,
5
}},
l1
);
p
.
add_instruction
(
migraphx
::
op
::
sub
{},
l2
,
l3
);
...
...
test/op_shape_test.cpp
View file @
3db703df
...
...
@@ -227,6 +227,16 @@ TEST_CASE(multibroadcast)
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{}};
throws_shape
(
migraphx
::
op
::
multibroadcast
{
lens
},
input
);
}
{
std
::
vector
<
std
::
size_t
>
lens
{
2
,
3
,
4
,
5
};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
3
,
4
}};
throws_shape
(
migraphx
::
op
::
multibroadcast
{
lens
},
input
);
}
{
std
::
vector
<
std
::
size_t
>
lens
{
2
,
3
,
4
,
5
};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
throws_shape
(
migraphx
::
op
::
multibroadcast
{
lens
},
input
);
}
}
TEST_CASE
(
broadcast
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment