Commit 80438de8 authored by Paul's avatar Paul
Browse files

Cleanup code

parent 571fc2cd
......@@ -129,29 +129,18 @@ struct onnx_parser
// output_lens = (3,2,7,5)
//
// Get lengths for both arguments
const std::vector<std::size_t>& s0 = args[0]->get_shape().lens();
const std::vector<std::size_t>& s1 = args[1]->get_shape().lens();
const std::vector<std::size_t>* s0 = &args[0]->get_shape().lens();
const std::vector<std::size_t>* s1 = &args[1]->get_shape().lens();
// Make sure s0 is the smaller size
if(s0->size() > s1->size())
std::swap(s0, s1);
// Copy the larger vector to output_lens
std::vector<std::size_t> output_lens = (s0.size() >= s1.size()) ? s0 : s1;
if(s0.size() >= s1.size())
{
// s0 is bigger, so iterate over the range of s1
auto offset = s0.size() - s1.size();
for(std::size_t i = 0; i < s1.size(); i++)
{
output_lens[i + offset] = std::max(s0[i + offset], s1[i]);
}
}
else
{
// s1 is bigger, so iterate over the range of s0
auto offset = s1.size() - s0.size();
for(std::size_t i = 0; i < s0.size(); i++)
{
output_lens[i + offset] = std::max(s0[i], s1[i + offset]);
}
}
std::vector<std::size_t> output_lens(s1->size());
auto offset = s1->size() - s0->size();
std::transform(s0->begin(), s0->end(), s1->begin() + offset, output_lens.begin() + offset, [](auto a, auto b) { return std::max(a, b); });
auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, args[0]);
auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, args[1]);
return prog.add_instruction(x, l0, l1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment