Commit 1f304aed authored by Scott Thornton's avatar Scott Thornton
Browse files

Fixed faulty add compute_shape when using multibroadcast

parent 6c42bc6e
......@@ -762,39 +762,34 @@ struct broadcast
struct multibroadcast
{
std::vector<std::size_t> output_lens;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.output_lens, "output_lens"));
}
std::string name() const { return "multibroadcast"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto t = inputs.at(0).type();
auto input = inputs.at(0);
if(input.lens().size() <= 0)
if (input.lens().size() <= 0)
MIGRAPH_THROW("inputs dimensions should be > 0");
if(input.lens().size() > output_lens.size())
if (input.lens().size() > output_lens.size())
MIGRAPH_THROW("inputs dimensions should <= output size");
std::vector<size_t> bcast_strides(output_lens.size(), 0);
auto offset = output_lens.size() - input.lens().size();
if(input.lens().size() < output_lens.size())
auto offset = output_lens.size()-input.lens().size();
for (int i = input.lens().size()-1; i >= 0; i--)
{
for(std::size_t i = output_lens.size() - 1; i > 0; i--)
if (output_lens[i+offset] == input.lens()[i])
{
if(output_lens[i] == input.lens()[i - offset])
{
bcast_strides[i] = input.strides()[i - offset];
}
}
}
else
{
for(std::size_t i = 0; i < input.lens().size(); i++)
{
if(output_lens[i] == input.lens()[i])
{
bcast_strides[i] = input.strides()[i];
}
bcast_strides[i+offset] = input.strides()[i];
}
}
return {t, output_lens, bcast_strides};
......@@ -833,7 +828,9 @@ struct binary
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(2).same_type().same_dims();
return inputs.at(0);
auto t = inputs.at(0).type();
auto lens = inputs.at(0).lens();
return {t, lens};
}
};
......
......@@ -107,6 +107,7 @@ struct onnx_parser
prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]);
return prog.add_instruction(x, args[0], l);
}
return prog.add_instruction(x, args);
}
else
{
......@@ -147,8 +148,10 @@ struct onnx_parser
output_lens[i + offset] = std::max(s0[i], s1[i + offset]);
}
}
auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, args[0]);
auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, args[1]);
return prog.add_instruction(x, l0, l1);
}
return prog.add_instruction(x, args);
});
}
......
......@@ -183,6 +183,13 @@ void multibroadcast_shape()
migraph::op::multibroadcast{lens},
input);
}
{
std::vector<std::size_t> lens{4, 4, 1, 3};
migraph::shape input{migraph::shape::float_type, {4, 1, 3}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 3, 3, 1}},
migraph::op::multibroadcast{lens},
input);
}
{
std::vector<std::size_t> lens{4, 1, 1, 3};
migraph::shape input{migraph::shape::float_type, {4, 1, 1, 1}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment