Commit 03ae8013 authored by Paul's avatar Paul
Browse files

Fix unit test

parent 1e5f7133
......@@ -17,10 +17,8 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
auto blen = s.lens()[n + 1];
if(astride == bstride * blen or alen == 1)
{
new_lens.push_back(alen * blen);
}
}
if(new_lens.size() != shapes.size())
return false;
std::size_t i = 0;
......@@ -37,10 +35,25 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
return true;
}
void reduce_dim1(std::vector<shape>& shapes)
{
if (std::any_of(shapes.begin(), shapes.end(), [&](const auto& s) {
return s.lens().back() != 1;
}))
return;
for(auto& s : shapes)
{
auto lens = s.lens();
auto strides = s.strides();
lens.pop_back();
strides.pop_back();
s = shape{s.type(), lens, strides};
}
}
std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n)
{
while(reduce_dim(shapes, n) and n < shapes.size()) {}
return n + 1;
}
void reduce_dim_all(std::vector<shape>& shapes)
......@@ -48,6 +61,7 @@ void reduce_dim_all(std::vector<shape>& shapes)
std::size_t n = 0;
while(n < shapes.front().lens().size() - 1)
n = reduce_dim_all(shapes, n);
reduce_dim1(shapes);
}
std::vector<std::size_t> base_lens(const std::vector<shape>& shapes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment