Commit 15f23855 authored by Khalique's avatar Khalique
Browse files

formatting

parent 21ec113b
...@@ -101,7 +101,7 @@ struct onnx_parser ...@@ -101,7 +101,7 @@ struct onnx_parser
return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...); return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
}); });
} }
template <class T> template <class T>
void add_binary_op(std::string name, T x) void add_binary_op(std::string name, T x)
{ {
...@@ -134,19 +134,19 @@ struct onnx_parser ...@@ -134,19 +134,19 @@ struct onnx_parser
{ {
if(arg0->get_shape() != arg1->get_shape()) if(arg0->get_shape() != arg1->get_shape())
{ {
// Example: // Example:
// s0 = (3,2,4,5) and s1 = (2,1,1) // s0 = (3,2,4,5) and s1 = (2,1,1)
// //
// In this case we need to broadcast (:,1,1) portion of // In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1 // s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5) // giving output_lens = (3,2,4,5)
// //
// Another example: // Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5) // s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis // In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving // of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5) // output_lens = (3,2,7,5)
// //
// Get lengths for both arguments // Get lengths for both arguments
const std::vector<std::size_t>* s0 = &arg0->get_shape().lens(); const std::vector<std::size_t>* s0 = &arg0->get_shape().lens();
const std::vector<std::size_t>* s1 = &arg1->get_shape().lens(); const std::vector<std::size_t>* s1 = &arg1->get_shape().lens();
...@@ -158,10 +158,10 @@ struct onnx_parser ...@@ -158,10 +158,10 @@ struct onnx_parser
std::vector<std::size_t> output_lens(s1->size()); std::vector<std::size_t> output_lens(s1->size());
auto offset = s1->size() - s0->size(); auto offset = s1->size() - s0->size();
std::transform(s0->begin(), std::transform(s0->begin(),
s0->end(), s0->end(),
s1->begin() + offset, s1->begin() + offset,
output_lens.begin() + offset, output_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); }); [](auto a, auto b) { return std::max(a, b); });
auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, arg0); auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, arg0);
auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, arg1); auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, arg1);
...@@ -185,9 +185,9 @@ struct onnx_parser ...@@ -185,9 +185,9 @@ struct onnx_parser
parse_sum(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_sum(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{ {
auto curr_sum = args.front(); auto curr_sum = args.front();
if (args.size() > 1) if(args.size() > 1)
{ {
for (auto it = std::next(args.begin()); it != args.end(); ++it) for(auto it = std::next(args.begin()); it != args.end(); ++it)
{ {
curr_sum = add_broadcastable_binary_op(curr_sum, *it, op::add{}); curr_sum = add_broadcastable_binary_op(curr_sum, *it, op::add{});
} }
...@@ -199,9 +199,9 @@ struct onnx_parser ...@@ -199,9 +199,9 @@ struct onnx_parser
parse_max(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_max(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{ {
auto curr_max = args.front(); auto curr_max = args.front();
if (args.size() > 1) if(args.size() > 1)
{ {
for (auto it = std::next(args.begin()); it != args.end(); ++it) for(auto it = std::next(args.begin()); it != args.end(); ++it)
{ {
curr_max = add_broadcastable_binary_op(curr_max, *it, op::max{}); curr_max = add_broadcastable_binary_op(curr_max, *it, op::max{});
} }
...@@ -213,9 +213,9 @@ struct onnx_parser ...@@ -213,9 +213,9 @@ struct onnx_parser
parse_min(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_min(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{ {
auto curr_min = args.front(); auto curr_min = args.front();
if (args.size() > 1) if(args.size() > 1)
{ {
for (auto it = std::next(args.begin()); it != args.end(); ++it) for(auto it = std::next(args.begin()); it != args.end(); ++it)
{ {
curr_min = add_broadcastable_binary_op(curr_min, *it, op::min{}); curr_min = add_broadcastable_binary_op(curr_min, *it, op::min{});
} }
......
...@@ -582,7 +582,6 @@ struct min_op ...@@ -582,7 +582,6 @@ struct min_op
} }
}; };
template <typename Op> template <typename Op>
struct cpu_binary struct cpu_binary
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment