"src/vscode:/vscode.git/clone" did not exist on "d70abafab21b41605740b2661d606c517320d044"
Commit dc945bf8 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 1ff459b4
...@@ -584,7 +584,7 @@ struct squeeze ...@@ -584,7 +584,7 @@ struct squeeze
} }
// squeezing a single element generates a scalar // squeezing a single element generates a scalar
if (new_lens.empty()) if(new_lens.empty())
{ {
return {type}; return {type};
} }
......
...@@ -154,7 +154,8 @@ struct onnx_parser ...@@ -154,7 +154,8 @@ struct onnx_parser
}); });
} }
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0, std::vector<std::size_t> s1) std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1)
{ {
// Example: // Example:
// s0 = (3,2,4,5) and s1 = (2,1,1) // s0 = (3,2,4,5) and s1 = (2,1,1)
...@@ -168,14 +169,17 @@ struct onnx_parser ...@@ -168,14 +169,17 @@ struct onnx_parser
// In this case we need to broadcast the (:,:,1:,:) axis // In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving // of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5) // output_lens = (3,2,7,5)
if (s0.size() > s1.size()) if(s0.size() > s1.size())
{ {
s0.swap(s1); s0.swap(s1);
} }
std::vector<std::size_t> out_lens(s1); std::vector<std::size_t> out_lens(s1);
auto offset = s1.size() - s0.size(); auto offset = s1.size() - s0.size();
std::transform(s0.begin(), s0.end(), s1.begin() + offset, out_lens.begin() + offset, std::transform(s0.begin(),
s0.end(),
s1.begin() + offset,
out_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); }); [](auto a, auto b) { return std::max(a, b); });
return out_lens; return out_lens;
...@@ -503,7 +507,7 @@ struct onnx_parser ...@@ -503,7 +507,7 @@ struct onnx_parser
out_lens.back() = l2->get_shape().lens().back(); out_lens.back() = l2->get_shape().lens().back();
auto l3 = args[2]; auto l3 = args[2];
auto l3_lens = l3->get_shape().lens(); auto l3_lens = l3->get_shape().lens();
if (!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end())) if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
{ {
l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]); l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
} }
...@@ -524,7 +528,7 @@ struct onnx_parser ...@@ -524,7 +528,7 @@ struct onnx_parser
// args[0] is a vector, prepend 1 to the shape // args[0] is a vector, prepend 1 to the shape
bool is_a_prepended = false; bool is_a_prepended = false;
if (l0_lens.size() == 1) if(l0_lens.size() == 1)
{ {
is_a_prepended = true; is_a_prepended = true;
l0_lens.insert(l0_lens.begin(), 1); l0_lens.insert(l0_lens.begin(), 1);
...@@ -532,7 +536,7 @@ struct onnx_parser ...@@ -532,7 +536,7 @@ struct onnx_parser
} }
bool is_b_appended = false; bool is_b_appended = false;
if (l1_lens.size() == 1) if(l1_lens.size() == 1)
{ {
is_b_appended = true; is_b_appended = true;
l1_lens.push_back(1); l1_lens.push_back(1);
...@@ -541,7 +545,7 @@ struct onnx_parser ...@@ -541,7 +545,7 @@ struct onnx_parser
instruction_ref bl0 = l0; instruction_ref bl0 = l0;
instruction_ref bl1 = l1; instruction_ref bl1 = l1;
if (!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend())) if(!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
{ {
auto l0_it = l0_lens.begin() + l0_lens.size() - 2; auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it); std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
...@@ -550,11 +554,11 @@ struct onnx_parser ...@@ -550,11 +554,11 @@ struct onnx_parser
auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens); auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end()); l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end()); l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
if (l0_lens != l0_broadcasted_lens) if(l0_lens != l0_broadcasted_lens)
{ {
bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0); bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
} }
if (l1_lens != l1_broadcasted_lens) if(l1_lens != l1_broadcasted_lens)
{ {
bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1); bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
} }
...@@ -562,12 +566,12 @@ struct onnx_parser ...@@ -562,12 +566,12 @@ struct onnx_parser
auto dot_res = prog.add_instruction(op::dot{1.0f, 0.0f}, bl0, bl1); auto dot_res = prog.add_instruction(op::dot{1.0f, 0.0f}, bl0, bl1);
int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size()); int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
if (is_a_prepended) if(is_a_prepended)
{ {
dot_res = prog.add_instruction(op::squeeze{{num_axis - 2}}, dot_res); dot_res = prog.add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
--num_axis; --num_axis;
} }
if (is_b_appended) if(is_b_appended)
{ {
dot_res = prog.add_instruction(op::squeeze{{num_axis - 1}}, dot_res); dot_res = prog.add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment