Commit 6c42bc6e authored by Scott Thornton's avatar Scott Thornton
Browse files

Formatting

parent 8fce4170
......@@ -769,29 +769,29 @@ struct multibroadcast
auto t = inputs.at(0).type();
auto input = inputs.at(0);
if (input.lens().size() <= 0)
if(input.lens().size() <= 0)
MIGRAPH_THROW("inputs dimensions should be > 0");
if (input.lens().size() > output_lens.size())
if(input.lens().size() > output_lens.size())
MIGRAPH_THROW("inputs dimensions should <= output size");
std::vector<size_t> bcast_strides(output_lens.size(), 0);
auto offset = output_lens.size()-input.lens().size();
if (input.lens().size() < output_lens.size())
auto offset = output_lens.size() - input.lens().size();
if(input.lens().size() < output_lens.size())
{
for (std::size_t i = output_lens.size()-1; i > 0; i--)
for(std::size_t i = output_lens.size() - 1; i > 0; i--)
{
if (output_lens[i] == input.lens()[i-offset])
if(output_lens[i] == input.lens()[i - offset])
{
bcast_strides[i] = input.strides()[i-offset];
bcast_strides[i] = input.strides()[i - offset];
}
}
}
else
{
for (std::size_t i = 0; i < input.lens().size(); i++)
for(std::size_t i = 0; i < input.lens().size(); i++)
{
if (output_lens[i] == input.lens()[i])
if(output_lens[i] == input.lens()[i])
{
bcast_strides[i] = input.strides()[i];
}
......
......@@ -93,7 +93,8 @@ struct onnx_parser
void add_broadcastable_binary_op(std::string name, T x)
{
ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
if (args.size() != 2) MIGRAPH_THROW("binaGry operators should have 2 operands");
if(args.size() != 2)
MIGRAPH_THROW("binaGry operators should have 2 operands");
if(contains(attributes, "broadcast"))
{
uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
......@@ -112,8 +113,8 @@ struct onnx_parser
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
......@@ -127,33 +128,30 @@ struct onnx_parser
const std::vector<std::size_t>& s1 = args[1]->get_shape().lens();
// Copy the larger vector to output_lens
std::vector<std::size_t> output_lens =
(s0.size() >= s1.size()) ? s0 : s1;
if (s0.size() >= s1.size())
std::vector<std::size_t> output_lens = (s0.size() >= s1.size()) ? s0 : s1;
if(s0.size() >= s1.size())
{
// s0 is bigger, so iterate over the range of s1
auto offset = s0.size() - s1.size();
for (std::size_t i = 0; i < s1.size(); i++)
for(std::size_t i = 0; i < s1.size(); i++)
{
output_lens[i+offset] = std::max(s0[i+offset], s1[i]);
output_lens[i + offset] = std::max(s0[i + offset], s1[i]);
}
}
else
{
// s1 is bigger, so iterate over the range of s0
auto offset = s1.size() - s0.size();
for (std::size_t i = 0; i < s0.size(); i++)
for(std::size_t i = 0; i < s0.size(); i++)
{
output_lens[i+offset] = std::max(s0[i], s1[i+offset]);
output_lens[i + offset] = std::max(s0[i], s1[i + offset]);
}
}
}
return prog.add_instruction(x, args);
});
}
template <class T>
void add_generic_op(std::string name, T x)
{
......
......@@ -512,10 +512,10 @@ void add_broadcast_test()
std::vector<float> a_data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
migraph::shape b_shape{migraph::shape::float_type, {2, 2, 1}};
std::vector<float> b_data{0, -1, -2, -3};
auto l1 = p.add_literal(migraph::literal{a_shape, a_data});
auto l2 = p.add_literal(migraph::literal{b_shape, b_data});
auto l3 = p.add_instruction(migraph::op::multibroadcast{{2, 2, 3}}, l1);
auto l4 = p.add_instruction(migraph::op::multibroadcast{{2, 2, 3}}, l2);
auto l1 = p.add_literal(migraph::literal{a_shape, a_data});
auto l2 = p.add_literal(migraph::literal{b_shape, b_data});
auto l3 = p.add_instruction(migraph::op::multibroadcast{{2, 2, 3}}, l1);
auto l4 = p.add_instruction(migraph::op::multibroadcast{{2, 2, 3}}, l2);
p.add_instruction(migraph::op::add{}, l3, l4);
p.compile(migraph::cpu::target{});
auto result = p.eval({});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment