Commit 1f304aed authored by Scott Thornton's avatar Scott Thornton
Browse files

Fixed faulty add compute_shape when using multibroadcast

parent 6c42bc6e
...@@ -762,39 +762,34 @@ struct broadcast ...@@ -762,39 +762,34 @@ struct broadcast
struct multibroadcast struct multibroadcast
{ {
std::vector<std::size_t> output_lens; std::vector<std::size_t> output_lens;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.output_lens, "output_lens"));
}
std::string name() const { return "multibroadcast"; } std::string name() const { return "multibroadcast"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1);
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
auto input = inputs.at(0); auto input = inputs.at(0);
if(input.lens().size() <= 0) if (input.lens().size() <= 0)
MIGRAPH_THROW("inputs dimensions should be > 0"); MIGRAPH_THROW("inputs dimensions should be > 0");
if(input.lens().size() > output_lens.size()) if (input.lens().size() > output_lens.size())
MIGRAPH_THROW("inputs dimensions should <= output size"); MIGRAPH_THROW("inputs dimensions should <= output size");
std::vector<size_t> bcast_strides(output_lens.size(), 0); std::vector<size_t> bcast_strides(output_lens.size(), 0);
auto offset = output_lens.size() - input.lens().size(); auto offset = output_lens.size()-input.lens().size();
if(input.lens().size() < output_lens.size()) for (int i = input.lens().size()-1; i >= 0; i--)
{
for(std::size_t i = output_lens.size() - 1; i > 0; i--)
{
if(output_lens[i] == input.lens()[i - offset])
{
bcast_strides[i] = input.strides()[i - offset];
}
}
}
else
{
for(std::size_t i = 0; i < input.lens().size(); i++)
{ {
if(output_lens[i] == input.lens()[i]) if (output_lens[i+offset] == input.lens()[i])
{ {
bcast_strides[i] = input.strides()[i]; bcast_strides[i+offset] = input.strides()[i];
}
} }
} }
return {t, output_lens, bcast_strides}; return {t, output_lens, bcast_strides};
...@@ -833,7 +828,9 @@ struct binary ...@@ -833,7 +828,9 @@ struct binary
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(2).same_type().same_dims(); check_shapes{inputs}.has(2).same_type().same_dims();
return inputs.at(0); auto t = inputs.at(0).type();
auto lens = inputs.at(0).lens();
return {t, lens};
} }
}; };
......
...@@ -107,6 +107,7 @@ struct onnx_parser ...@@ -107,6 +107,7 @@ struct onnx_parser
prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]); prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]);
return prog.add_instruction(x, args[0], l); return prog.add_instruction(x, args[0], l);
} }
return prog.add_instruction(x, args);
} }
else else
{ {
...@@ -147,8 +148,10 @@ struct onnx_parser ...@@ -147,8 +148,10 @@ struct onnx_parser
output_lens[i + offset] = std::max(s0[i], s1[i + offset]); output_lens[i + offset] = std::max(s0[i], s1[i + offset]);
} }
} }
auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, args[0]);
auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, args[1]);
return prog.add_instruction(x, l0, l1);
} }
return prog.add_instruction(x, args);
}); });
} }
......
...@@ -183,6 +183,13 @@ void multibroadcast_shape() ...@@ -183,6 +183,13 @@ void multibroadcast_shape()
migraph::op::multibroadcast{lens}, migraph::op::multibroadcast{lens},
input); input);
} }
{
std::vector<std::size_t> lens{4, 4, 1, 3};
migraph::shape input{migraph::shape::float_type, {4, 1, 3}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 3, 3, 1}},
migraph::op::multibroadcast{lens},
input);
}
{ {
std::vector<std::size_t> lens{4, 1, 1, 3}; std::vector<std::size_t> lens{4, 1, 1, 3};
migraph::shape input{migraph::shape::float_type, {4, 1, 1, 1}}; migraph::shape input{migraph::shape::float_type, {4, 1, 1, 1}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment