Unverified Commit 3db703df authored by mvermeulen's avatar mvermeulen Committed by GitHub
Browse files

Merge pull request #294 from ROCmSoftwarePlatform/multibroadcast_bug

Fix a bug in the multibroadcast
parents 51f264a6 93d44e6e
...@@ -35,14 +35,28 @@ struct multibroadcast ...@@ -35,14 +35,28 @@ struct multibroadcast
auto input = inputs.at(0); auto input = inputs.at(0);
if(input.lens().empty()) if(input.lens().empty())
MIGRAPHX_THROW("inputs dimensions should be > 0"); {
MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should be > 0");
}
if(input.lens().size() > output_lens.size()) if(input.lens().size() > output_lens.size())
MIGRAPHX_THROW("inputs dimensions should <= output size"); {
MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should <= output size");
}
std::vector<size_t> bcast_strides(output_lens.size(), 0);
auto offset = output_lens.size() - input.lens().size(); auto offset = output_lens.size() - input.lens().size();
for(std::ptrdiff_t i = input.lens().size() - 1; i >= 0; i--) for(std::ptrdiff_t i = input.lens().size() - 1; i >= 0; i--)
{
if(output_lens[i + offset] != input.lens()[i] and input.lens()[i] != 1)
{
MIGRAPHX_THROW("MULTIBROADCAST: input shape {" + to_string_range(input.lens()) +
"} cannot be broadcasted to {" + to_string_range(output_lens) +
"}!");
}
}
std::vector<size_t> bcast_strides(output_lens.size(), 0);
for(std::ptrdiff_t i = input.lens().size() - 1; i >= 0; i--)
{ {
if(output_lens[i + offset] == input.lens()[i]) if(output_lens[i + offset] == input.lens()[i])
{ {
......
...@@ -182,7 +182,15 @@ struct onnx_parser ...@@ -182,7 +182,15 @@ struct onnx_parser
s0.end(), s0.end(),
s1.begin() + offset, s1.begin() + offset,
out_lens.begin() + offset, out_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); }); [&](auto a, auto b) {
if(a != b and a != 1 and b != 1)
{
MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" +
to_string_range(s0) + "} and {" +
to_string_range(s1) + "} mismatch!");
}
return std::max(a, b);
});
return out_lens; return out_lens;
} }
......
implicit_bcast-example:q add2:u
 
0 0
12"Addtest-multi_bcastZ 1out"Add subtraction2Z
0 0
 
 
 
 
Z Z
1 1
 
 
b 
2 b
out
 
 
 
 
B B
\ No newline at end of file
 subtraction2:q add2:q
 
0 0
1out"Sub subtraction2Z 1out"Sub subtraction2Z
...@@ -10,11 +10,11 @@ ...@@ -10,11 +10,11 @@
Z Z
1 1
 
 
b b
out out
 
 
 
 
B B
\ No newline at end of file
...@@ -350,7 +350,7 @@ TEST_CASE(implicit_add_bcast_test) ...@@ -350,7 +350,7 @@ TEST_CASE(implicit_add_bcast_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0); auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1); auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, l2, l3); p.add_instruction(migraphx::op::add{}, l2, l3);
...@@ -377,7 +377,7 @@ TEST_CASE(implicit_sub_bcast_test) ...@@ -377,7 +377,7 @@ TEST_CASE(implicit_sub_bcast_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0); auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1); auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::sub{}, l2, l3); p.add_instruction(migraphx::op::sub{}, l2, l3);
......
...@@ -227,6 +227,16 @@ TEST_CASE(multibroadcast) ...@@ -227,6 +227,16 @@ TEST_CASE(multibroadcast)
migraphx::shape input{migraphx::shape::float_type, {}}; migraphx::shape input{migraphx::shape::float_type, {}};
throws_shape(migraphx::op::multibroadcast{lens}, input); throws_shape(migraphx::op::multibroadcast{lens}, input);
} }
{
std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape input{migraphx::shape::float_type, {3, 4}};
throws_shape(migraphx::op::multibroadcast{lens}, input);
}
{
std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4}};
throws_shape(migraphx::op::multibroadcast{lens}, input);
}
} }
TEST_CASE(broadcast) TEST_CASE(broadcast)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment