Commit 15f23855 authored by Khalique's avatar Khalique
Browse files

formatting

parent 21ec113b
......@@ -101,7 +101,7 @@ struct onnx_parser
return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
});
}
template <class T>
void add_binary_op(std::string name, T x)
{
......@@ -134,19 +134,19 @@ struct onnx_parser
{
if(arg0->get_shape() != arg1->get_shape())
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
//
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
//
// Get lengths for both arguments
const std::vector<std::size_t>* s0 = &arg0->get_shape().lens();
const std::vector<std::size_t>* s1 = &arg1->get_shape().lens();
......@@ -158,10 +158,10 @@ struct onnx_parser
std::vector<std::size_t> output_lens(s1->size());
auto offset = s1->size() - s0->size();
std::transform(s0->begin(),
s0->end(),
s1->begin() + offset,
output_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); });
s0->end(),
s1->begin() + offset,
output_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); });
auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, arg0);
auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, arg1);
......@@ -185,9 +185,9 @@ struct onnx_parser
parse_sum(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
auto curr_sum = args.front();
if (args.size() > 1)
if(args.size() > 1)
{
for (auto it = std::next(args.begin()); it != args.end(); ++it)
for(auto it = std::next(args.begin()); it != args.end(); ++it)
{
curr_sum = add_broadcastable_binary_op(curr_sum, *it, op::add{});
}
......@@ -199,9 +199,9 @@ struct onnx_parser
parse_max(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
auto curr_max = args.front();
if (args.size() > 1)
if(args.size() > 1)
{
for (auto it = std::next(args.begin()); it != args.end(); ++it)
for(auto it = std::next(args.begin()); it != args.end(); ++it)
{
curr_max = add_broadcastable_binary_op(curr_max, *it, op::max{});
}
......@@ -213,9 +213,9 @@ struct onnx_parser
parse_min(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
auto curr_min = args.front();
if (args.size() > 1)
if(args.size() > 1)
{
for (auto it = std::next(args.begin()); it != args.end(); ++it)
for(auto it = std::next(args.begin()); it != args.end(); ++it)
{
curr_min = add_broadcastable_binary_op(curr_min, *it, op::min{});
}
......
......@@ -582,7 +582,6 @@ struct min_op
}
};
template <typename Op>
struct cpu_binary
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment