Commit e5723db5 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into argmax_min

parents ab93262e 3db703df
......@@ -35,14 +35,28 @@ struct multibroadcast
auto input = inputs.at(0);
if(input.lens().empty())
MIGRAPHX_THROW("inputs dimensions should be > 0");
{
MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should be > 0");
}
if(input.lens().size() > output_lens.size())
MIGRAPHX_THROW("inputs dimensions should <= output size");
{
MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should <= output size");
}
std::vector<size_t> bcast_strides(output_lens.size(), 0);
auto offset = output_lens.size() - input.lens().size();
for(std::ptrdiff_t i = input.lens().size() - 1; i >= 0; i--)
{
if(output_lens[i + offset] != input.lens()[i] and input.lens()[i] != 1)
{
MIGRAPHX_THROW("MULTIBROADCAST: input shape {" + to_string_range(input.lens()) +
"} cannot be broadcasted to {" + to_string_range(output_lens) +
"}!");
}
}
std::vector<size_t> bcast_strides(output_lens.size(), 0);
for(std::ptrdiff_t i = input.lens().size() - 1; i >= 0; i--)
{
if(output_lens[i + offset] == input.lens()[i])
{
......
......@@ -185,7 +185,15 @@ struct onnx_parser
s0.end(),
s1.begin() + offset,
out_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); });
[&](auto a, auto b) {
if(a != b and a != 1 and b != 1)
{
MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" +
to_string_range(s0) + "} and {" +
to_string_range(s1) + "} mismatch!");
}
return std::max(a, b);
});
return out_lens;
}
......
implicit_bcast-example:q

add2:u

0
12"Addtest-multi_bcastZ
1out"Add subtraction2Z
0




Z
1

Z
1


b
2

b
out




B
\ No newline at end of file
B
 subtraction2:q
add2:q

0
1out"Sub subtraction2Z
......@@ -10,11 +10,11 @@
Z
1


b

b
out




B
\ No newline at end of file
B
......@@ -350,7 +350,7 @@ TEST_CASE(implicit_add_bcast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, l2, l3);
......@@ -377,7 +377,7 @@ TEST_CASE(implicit_sub_bcast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::sub{}, l2, l3);
......
......@@ -227,6 +227,16 @@ TEST_CASE(multibroadcast)
migraphx::shape input{migraphx::shape::float_type, {}};
throws_shape(migraphx::op::multibroadcast{lens}, input);
}
{
std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape input{migraphx::shape::float_type, {3, 4}};
throws_shape(migraphx::op::multibroadcast{lens}, input);
}
{
std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4}};
throws_shape(migraphx::op::multibroadcast{lens}, input);
}
}
TEST_CASE(broadcast)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment