Commit 1dca597f authored by Scott Thornton's avatar Scott Thornton
Browse files

Formatting

parent 8e4b1022
......@@ -622,13 +622,14 @@ struct broadcast
std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const
{
auto t = inputs.at(0).type();
auto input = inputs.at(0);
auto t = inputs.at(0).type();
auto input = inputs.at(0);
std::vector<size_t> bcast_strides(broadcast_shape.lens().size(), 0);
if(std::all_of(
broadcast_shape.lens().cbegin(), broadcast_shape.lens().cend(), [&](auto x) { return x == 1; }))
if(std::all_of(broadcast_shape.lens().cbegin(), broadcast_shape.lens().cend(), [&](auto x) {
return x == 1;
}))
{
if(axis != 0)
MIGRAPH_THROW("when broadcasting tensor of size 1, axis should be 0");
......@@ -637,7 +638,8 @@ struct broadcast
else
{
assert(broadcast_shape.lens().size() - axis >= input.lens().size());
if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_shape.lens().begin() + axis))
if(!std::equal(
input.lens().begin(), input.lens().end(), broadcast_shape.lens().begin() + axis))
MIGRAPH_THROW("when broadcasting success sizes must match");
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
return {t, broadcast_shape.lens(), std::move(bcast_strides)};
......
......@@ -93,7 +93,8 @@ struct onnx_parser
uint64_t axis = (contains(attributes, "axis"))
? parse_value(attributes.at("axis")).at<uint64_t>()
: 0;
auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]);
auto l =
prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]);
return prog.add_instruction(x, args[0], l);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment