Commit 26277ec5 authored by Paul's avatar Paul
Browse files

Fix accuracy bug when vectorizing slices

parent ed7973d1
...@@ -61,12 +61,17 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs ...@@ -61,12 +61,17 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs
[&](const auto& input) -> std::size_t { [&](const auto& input) -> std::size_t {
auto stride = input.strides()[axis]; auto stride = input.strides()[axis];
auto len = input.lens()[axis]; auto len = input.lens()[axis];
if(stride != 0 and stride != 1) if(not contains({0, 1}, stride))
return 1; return 1;
if(len == 1 and input.elements() > sizes.front()) if(len == 1 and input.elements() > sizes.front())
return sizes.front(); return sizes.front();
auto it = std::find_if( auto it = std::find_if(
sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; }); sizes.begin(), sizes.end(), [&](auto vsize) {
// The len is divisible by the size and all the strides are divisible by the size
return (len % vsize) == 0 and std::all_of(input.strides().begin(), input.strides().end(), [&](auto i) {
return contains({0, 1}, i) or i % vsize == 0;
});
});
if(it != sizes.end()) if(it != sizes.end())
return *it; return *it;
return 1; return 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment