Commit 6c42bc6e authored by Scott Thornton's avatar Scott Thornton
Browse files

Formatting

parent 8fce4170
...@@ -769,29 +769,29 @@ struct multibroadcast ...@@ -769,29 +769,29 @@ struct multibroadcast
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
auto input = inputs.at(0); auto input = inputs.at(0);
if (input.lens().size() <= 0) if(input.lens().size() <= 0)
MIGRAPH_THROW("inputs dimensions should be > 0"); MIGRAPH_THROW("inputs dimensions should be > 0");
if (input.lens().size() > output_lens.size()) if(input.lens().size() > output_lens.size())
MIGRAPH_THROW("inputs dimensions should <= output size"); MIGRAPH_THROW("inputs dimensions should <= output size");
std::vector<size_t> bcast_strides(output_lens.size(), 0); std::vector<size_t> bcast_strides(output_lens.size(), 0);
auto offset = output_lens.size()-input.lens().size(); auto offset = output_lens.size() - input.lens().size();
if (input.lens().size() < output_lens.size()) if(input.lens().size() < output_lens.size())
{ {
for (std::size_t i = output_lens.size()-1; i > 0; i--) for(std::size_t i = output_lens.size() - 1; i > 0; i--)
{ {
if (output_lens[i] == input.lens()[i-offset]) if(output_lens[i] == input.lens()[i - offset])
{ {
bcast_strides[i] = input.strides()[i-offset]; bcast_strides[i] = input.strides()[i - offset];
} }
} }
} }
else else
{ {
for (std::size_t i = 0; i < input.lens().size(); i++) for(std::size_t i = 0; i < input.lens().size(); i++)
{ {
if (output_lens[i] == input.lens()[i]) if(output_lens[i] == input.lens()[i])
{ {
bcast_strides[i] = input.strides()[i]; bcast_strides[i] = input.strides()[i];
} }
......
...@@ -93,7 +93,8 @@ struct onnx_parser ...@@ -93,7 +93,8 @@ struct onnx_parser
void add_broadcastable_binary_op(std::string name, T x) void add_broadcastable_binary_op(std::string name, T x)
{ {
ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) { ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
if (args.size() != 2) MIGRAPH_THROW("binaGry operators should have 2 operands"); if(args.size() != 2)
MIGRAPH_THROW("binaGry operators should have 2 operands");
if(contains(attributes, "broadcast")) if(contains(attributes, "broadcast"))
{ {
uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>(); uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
...@@ -127,33 +128,30 @@ struct onnx_parser ...@@ -127,33 +128,30 @@ struct onnx_parser
const std::vector<std::size_t>& s1 = args[1]->get_shape().lens(); const std::vector<std::size_t>& s1 = args[1]->get_shape().lens();
// Copy the larger vector to output_lens // Copy the larger vector to output_lens
std::vector<std::size_t> output_lens = std::vector<std::size_t> output_lens = (s0.size() >= s1.size()) ? s0 : s1;
(s0.size() >= s1.size()) ? s0 : s1; if(s0.size() >= s1.size())
if (s0.size() >= s1.size())
{ {
// s0 is bigger, so iterate over the range of s1 // s0 is bigger, so iterate over the range of s1
auto offset = s0.size() - s1.size(); auto offset = s0.size() - s1.size();
for (std::size_t i = 0; i < s1.size(); i++) for(std::size_t i = 0; i < s1.size(); i++)
{ {
output_lens[i+offset] = std::max(s0[i+offset], s1[i]); output_lens[i + offset] = std::max(s0[i + offset], s1[i]);
} }
} }
else else
{ {
// s1 is bigger, so iterate over the range of s0 // s1 is bigger, so iterate over the range of s0
auto offset = s1.size() - s0.size(); auto offset = s1.size() - s0.size();
for (std::size_t i = 0; i < s0.size(); i++) for(std::size_t i = 0; i < s0.size(); i++)
{ {
output_lens[i+offset] = std::max(s0[i], s1[i+offset]); output_lens[i + offset] = std::max(s0[i], s1[i + offset]);
} }
} }
} }
return prog.add_instruction(x, args); return prog.add_instruction(x, args);
}); });
} }
template <class T> template <class T>
void add_generic_op(std::string name, T x) void add_generic_op(std::string name, T x)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment