Commit dc945bf8 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 1ff459b4
......@@ -584,7 +584,7 @@ struct squeeze
}
// squeezing a single element generates a scalar
if (new_lens.empty())
if(new_lens.empty())
{
return {type};
}
......
......@@ -154,7 +154,8 @@ struct onnx_parser
});
}
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0, std::vector<std::size_t> s1)
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1)
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
......@@ -168,14 +169,17 @@ struct onnx_parser
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
if (s0.size() > s1.size())
if(s0.size() > s1.size())
{
s0.swap(s1);
}
std::vector<std::size_t> out_lens(s1);
auto offset = s1.size() - s0.size();
std::transform(s0.begin(), s0.end(), s1.begin() + offset, out_lens.begin() + offset,
std::transform(s0.begin(),
s0.end(),
s1.begin() + offset,
out_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); });
return out_lens;
......@@ -503,7 +507,7 @@ struct onnx_parser
out_lens.back() = l2->get_shape().lens().back();
auto l3 = args[2];
auto l3_lens = l3->get_shape().lens();
if (!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
{
l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
}
......@@ -524,7 +528,7 @@ struct onnx_parser
// args[0] is a vector, prepend 1 to the shape
bool is_a_prepended = false;
if (l0_lens.size() == 1)
if(l0_lens.size() == 1)
{
is_a_prepended = true;
l0_lens.insert(l0_lens.begin(), 1);
......@@ -532,7 +536,7 @@ struct onnx_parser
}
bool is_b_appended = false;
if (l1_lens.size() == 1)
if(l1_lens.size() == 1)
{
is_b_appended = true;
l1_lens.push_back(1);
......@@ -541,7 +545,7 @@ struct onnx_parser
instruction_ref bl0 = l0;
instruction_ref bl1 = l1;
if (!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
if(!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
{
auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
......@@ -550,11 +554,11 @@ struct onnx_parser
auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
if (l0_lens != l0_broadcasted_lens)
if(l0_lens != l0_broadcasted_lens)
{
bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
}
if (l1_lens != l1_broadcasted_lens)
if(l1_lens != l1_broadcasted_lens)
{
bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
}
......@@ -562,12 +566,12 @@ struct onnx_parser
auto dot_res = prog.add_instruction(op::dot{1.0f, 0.0f}, bl0, bl1);
int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
if (is_a_prepended)
if(is_a_prepended)
{
dot_res = prog.add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
--num_axis;
}
if (is_b_appended)
if(is_b_appended)
{
dot_res = prog.add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment