Commit 8724a471 authored by Paul's avatar Paul
Browse files

Fix indexing in the shape class

parent 17c6d683
......@@ -116,15 +116,20 @@ std::size_t shape::index(std::size_t i) const
if(this->standard())
return i;
else
return std::inner_product(this->lens().begin(),
this->lens().end(),
this->strides().begin(),
std::size_t{0},
std::plus<std::size_t>{},
[&](std::size_t len, std::size_t stride) {
assert(stride > 0 and len > 0);
return ((i / stride) % len) * stride;
});
{
std::size_t s = 1;
std::size_t result = 0;
for(std::size_t j = 0;j < this->lens().size();j++)
{
const std::size_t k = this->lens().size() - j - 1;
const std::size_t stride = this->strides()[k];
const std::size_t len = this->lens()[k];
const std::size_t idx = (i % (s * len)) / s;
result += stride * idx;
s *= len;
}
return result;
}
}
bool shape::packed() const { return this->elements() == this->element_space(); }
......
......@@ -97,6 +97,72 @@ void test_shape4()
EXPECT(s.index(s.elements() - 1) == s.elements() - 1);
}
void test_shape42()
{
migraph::shape s{migraph::shape::float_type, {100, 32, 8, 8}, {2048, 64, 8, 1}};
EXPECT(s.standard());
EXPECT(s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
EXPECT(s.type() == migraph::shape::float_type);
EXPECT(s.lens()[0] == 100);
EXPECT(s.lens()[1] == 32);
EXPECT(s.lens()[2] == 8);
EXPECT(s.lens()[3] == 8);
EXPECT(s.strides()[0] == s.lens()[1] * s.strides()[1]);
EXPECT(s.strides()[1] == s.lens()[2] * s.strides()[2]);
EXPECT(s.strides()[2] == s.lens()[3] * s.strides()[3]);
EXPECT(s.strides()[3] == 1);
EXPECT(s.elements() == 100 * 32 * 8 * 8);
EXPECT(s.bytes() == 100 * 32 * 8 * 8 * sizeof(float));
EXPECT(s.index({0, 0, 0, 0}) == 0);
EXPECT(s.index({0, 0, 0, 1}) == 1);
EXPECT(s.index({0, 0, 0, 0}) == s.index(0));
EXPECT(s.index({0, 0, 0, 1}) == s.index(1));
EXPECT(s.index({0, 0, 1, 0}) == s.index(8));
EXPECT(s.index({0, 1, 0, 0}) == s.index(8 * 8));
EXPECT(s.index({1, 0, 0, 0}) == s.index(8 * 8 * 32));
EXPECT(s.index(0) == 0);
EXPECT(s.index(1) == 1);
EXPECT(s.index(8) == 8);
EXPECT(s.index(8 * 8) == 8 * 8);
EXPECT(s.index(8 * 8 * 32) == 8 * 8 * 32);
EXPECT(s.index(s.elements() - 1) == s.elements() - 1);
}
void test_shape4_transposed()
{
migraph::shape s{migraph::shape::float_type, {32, 100, 8, 8}, {64, 2048, 8, 1}};
EXPECT(s.transposed());
EXPECT(s.packed());
EXPECT(not s.standard());
EXPECT(not s.broadcasted());
EXPECT(s.type() == migraph::shape::float_type);
EXPECT(s.lens()[0] == 32);
EXPECT(s.lens()[1] == 100);
EXPECT(s.lens()[2] == 8);
EXPECT(s.lens()[3] == 8);
EXPECT(s.strides()[0] == 64);
EXPECT(s.strides()[1] == 2048);
EXPECT(s.strides()[2] == 8);
EXPECT(s.strides()[3] == 1);
EXPECT(s.elements() == 100 * 32 * 8 * 8);
EXPECT(s.bytes() == 100 * 32 * 8 * 8 * sizeof(float));
EXPECT(s.index({0, 0, 0, 0}) == 0);
EXPECT(s.index({0, 0, 0, 1}) == 1);
EXPECT(s.index({0, 0, 0, 0}) == s.index(0));
EXPECT(s.index({0, 0, 0, 1}) == s.index(1));
EXPECT(s.index({0, 0, 1, 0}) == s.index(8));
EXPECT(s.index({0, 1, 0, 0}) == s.index(8 * 8));
EXPECT(s.index({1, 0, 0, 0}) == s.index(8 * 8 * 100));
EXPECT(s.index(0) == 0);
EXPECT(s.index(1) == 1);
EXPECT(s.index(8) == 8);
EXPECT(s.index(8 * 8) == 2048);
EXPECT(s.index(8 * 8 * 100) == 64);
EXPECT(s.index(s.elements() - 1) == s.elements() - 1);
}
void test_shape4_nonpacked()
{
std::vector<std::size_t> lens = {100, 32, 8, 8};
......@@ -134,11 +200,10 @@ void test_shape4_nonpacked()
EXPECT(s.index(1) == 1);
EXPECT(s.index({0, 0, 0, 0}) == 0);
EXPECT(s.index({0, 0, 0, 1}) == s.index(1));
// TODO: Fix these tests
// EXPECT(s.index({0, 0, 1, 0}) == s.index(8));
// EXPECT(s.index({0, 1, 0, 0}) == s.index(8 * 8));
// EXPECT(s.index({1, 0, 0, 0}) == s.index(8 * 8 * 32));
// EXPECT(s.index(s.elements() - 1) == 469273);
EXPECT(s.index({0, 0, 1, 0}) == s.index(8));
EXPECT(s.index({0, 1, 0, 0}) == s.index(8 * 8));
EXPECT(s.index({1, 0, 0, 0}) == s.index(8 * 8 * 32));
EXPECT(s.index(s.elements() - 1) == 469273);
}
int main()
......@@ -151,5 +216,7 @@ int main()
test_shape_broadcasted();
test_shape_default_copy();
test_shape4();
test_shape42();
test_shape4_transposed();
test_shape4_nonpacked();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment