Unverified Commit b5fcc0bc authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into blas_tuning

parents e234f631 71c8181c
......@@ -222,11 +222,15 @@ struct shape
/// Map element index to space index
std::size_t index(std::size_t i) const;
std::vector<std::size_t> multi(std::size_t i) const;
void multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const;
/// Map element index to multi-dimensional index
std::vector<std::size_t> multi(std::size_t idx) const;
/// Returns true if the shape is packed (number of elements and buffer size the same) with no
/// padding
/// Map element index to multi-dimensional index and put them them into location provided by
/// pointers
void multi_copy(std::size_t idx, std::size_t* start, const std::size_t* end) const;
/// Returns true if the shape is packed (number of elements and buffer size the same) with
/// no padding
bool packed() const;
/// Returns true is the shape has been transposed. That is the strides are not in descending
......
......@@ -361,29 +361,26 @@ std::size_t shape::index(std::size_t i) const
}
}
std::vector<std::size_t> shape::multi(std::size_t i) const
std::vector<std::size_t> shape::multi(std::size_t idx) const
{
assert(this->standard());
assert(idx < elements());
std::vector<std::size_t> indices(lens().size());
multi_copy(i, indices.data(), indices.data() + lens().size());
multi_copy(idx, indices.data(), indices.data() + lens().size());
return indices;
}
void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const
void shape::multi_copy(std::size_t idx, std::size_t* start, const std::size_t* end) const
{
assert(this->standard());
size_t tidx = idx;
(void)end;
assert(idx < elements());
assert(lens().size() <= (end - start));
std::transform(strides().begin(),
strides().end(),
lens().begin(),
start,
[&](std::size_t stride, std::size_t len) {
assert(len > 0 and stride > 0);
return (i / stride) % len;
});
for(size_t ii = lens().size() - 1; ii > 0; ii--)
{
*(start + ii) = tidx % lens()[ii];
tidx = tidx / lens()[ii];
}
*start = tidx;
}
bool shape::packed() const
......
......@@ -30,6 +30,7 @@
#include <array>
#include <algorithm>
#include <numeric>
#include <migraphx/verify.hpp>
#include "test.hpp"
TEST_CASE(test_shape_default)
......@@ -929,4 +930,16 @@ TEST_CASE(test_with_type)
EXPECT(s.strides() == new_s.strides());
}
TEST_CASE(test_multi_index)
{
migraphx::shape s{migraphx::shape::float_type, {2, 4, 6}};
EXPECT(migraphx::verify_range(s.multi(0), std::vector<size_t>{0, 0, 0}));
EXPECT(migraphx::verify_range(s.multi(4), std::vector<size_t>{0, 0, 4}));
EXPECT(migraphx::verify_range(s.multi(6), std::vector<size_t>{0, 1, 0}));
EXPECT(migraphx::verify_range(s.multi(8), std::vector<size_t>{0, 1, 2}));
EXPECT(migraphx::verify_range(s.multi(24), std::vector<size_t>{1, 0, 0}));
EXPECT(migraphx::verify_range(s.multi(30), std::vector<size_t>{1, 1, 0}));
EXPECT(migraphx::verify_range(s.multi(34), std::vector<size_t>{1, 1, 4}));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment