Commit 03ae8013 authored by Paul's avatar Paul
Browse files

Fix unit test

parent 1e5f7133
...@@ -17,9 +17,7 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n) ...@@ -17,9 +17,7 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
auto blen = s.lens()[n + 1]; auto blen = s.lens()[n + 1];
if(astride == bstride * blen or alen == 1) if(astride == bstride * blen or alen == 1)
{
new_lens.push_back(alen * blen); new_lens.push_back(alen * blen);
}
} }
if(new_lens.size() != shapes.size()) if(new_lens.size() != shapes.size())
return false; return false;
...@@ -37,10 +35,25 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n) ...@@ -37,10 +35,25 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
return true; return true;
} }
void reduce_dim1(std::vector<shape>& shapes)
{
if (std::any_of(shapes.begin(), shapes.end(), [&](const auto& s) {
return s.lens().back() != 1;
}))
return;
for(auto& s : shapes)
{
auto lens = s.lens();
auto strides = s.strides();
lens.pop_back();
strides.pop_back();
s = shape{s.type(), lens, strides};
}
}
std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n) std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n)
{ {
while(reduce_dim(shapes, n) and n < shapes.size()) {} while(reduce_dim(shapes, n) and n < shapes.size()) {}
return n + 1; return n + 1;
} }
void reduce_dim_all(std::vector<shape>& shapes) void reduce_dim_all(std::vector<shape>& shapes)
...@@ -48,6 +61,7 @@ void reduce_dim_all(std::vector<shape>& shapes) ...@@ -48,6 +61,7 @@ void reduce_dim_all(std::vector<shape>& shapes)
std::size_t n = 0; std::size_t n = 0;
while(n < shapes.front().lens().size() - 1) while(n < shapes.front().lens().size() - 1)
n = reduce_dim_all(shapes, n); n = reduce_dim_all(shapes, n);
reduce_dim1(shapes);
} }
std::vector<std::size_t> base_lens(const std::vector<shape>& shapes) std::vector<std::size_t> base_lens(const std::vector<shape>& shapes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment