Commit 1dca597f authored by Scott Thornton's avatar Scott Thornton
Browse files

Formatting

parent 8e4b1022
...@@ -622,13 +622,14 @@ struct broadcast ...@@ -622,13 +622,14 @@ struct broadcast
std::string name() const { return "broadcast"; } std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
auto input = inputs.at(0); auto input = inputs.at(0);
std::vector<size_t> bcast_strides(broadcast_shape.lens().size(), 0); std::vector<size_t> bcast_strides(broadcast_shape.lens().size(), 0);
if(std::all_of( if(std::all_of(broadcast_shape.lens().cbegin(), broadcast_shape.lens().cend(), [&](auto x) {
broadcast_shape.lens().cbegin(), broadcast_shape.lens().cend(), [&](auto x) { return x == 1; })) return x == 1;
}))
{ {
if(axis != 0) if(axis != 0)
MIGRAPH_THROW("when broadcasting tensor of size 1, axis should be 0"); MIGRAPH_THROW("when broadcasting tensor of size 1, axis should be 0");
...@@ -637,7 +638,8 @@ struct broadcast ...@@ -637,7 +638,8 @@ struct broadcast
else else
{ {
assert(broadcast_shape.lens().size() - axis >= input.lens().size()); assert(broadcast_shape.lens().size() - axis >= input.lens().size());
if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_shape.lens().begin() + axis)) if(!std::equal(
input.lens().begin(), input.lens().end(), broadcast_shape.lens().begin() + axis))
MIGRAPH_THROW("when broadcasting success sizes must match"); MIGRAPH_THROW("when broadcasting success sizes must match");
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis); std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
return {t, broadcast_shape.lens(), std::move(bcast_strides)}; return {t, broadcast_shape.lens(), std::move(bcast_strides)};
......
...@@ -93,7 +93,8 @@ struct onnx_parser ...@@ -93,7 +93,8 @@ struct onnx_parser
uint64_t axis = (contains(attributes, "axis")) uint64_t axis = (contains(attributes, "axis"))
? parse_value(attributes.at("axis")).at<uint64_t>() ? parse_value(attributes.at("axis")).at<uint64_t>()
: 0; : 0;
auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]); auto l =
prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]);
return prog.add_instruction(x, args[0], l); return prog.add_instruction(x, args[0], l);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment