shape_test.cpp 3.59 KB
Newer Older
Paul's avatar
Paul committed
1

Paul's avatar
Paul committed
2
#include <migraph/shape.hpp>
Paul's avatar
Paul committed
3
4
5
#include <algorithm>
#include <array>
#include <functional>
#include <numeric>
#include <vector>
Paul's avatar
Paul committed
6
7
8
9
#include "test.hpp"

void test_shape_assign()
{
Paul's avatar
Paul committed
10
11
    migraph::shape s1{migraph::shape::float_type, {100, 32, 8, 8}};
    migraph::shape s2 = s1; // NOLINT
Paul's avatar
Paul committed
12
13
14
15
    EXPECT(s1 == s2);
    EXPECT(!(s1 != s2));
}

Paul's avatar
Paul committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
void test_shape_packed_default()
{
    migraph::shape s{migraph::shape::float_type, {2, 2}};
    EXPECT(s.packed());
}

void test_shape_packed()
{
    migraph::shape s{migraph::shape::float_type, {2, 2}, {2, 1}};
    EXPECT(s.packed());
}

void test_shape_transposed()
{
    migraph::shape s{migraph::shape::float_type, {2, 2}, {1, 2}};
    EXPECT(not s.packed());
}

Paul's avatar
Paul committed
34
35
void test_shape_default()
{
Paul's avatar
Paul committed
36
37
    migraph::shape s1{};
    migraph::shape s2{};
Paul's avatar
Paul committed
38
39
40
41
    EXPECT(s1 == s2);
    EXPECT(!(s1 != s2));
}

Paul's avatar
Paul committed
42
43
void test_shape4()
{
Paul's avatar
Paul committed
44
    migraph::shape s{migraph::shape::float_type, {100, 32, 8, 8}};
Paul's avatar
Paul committed
45
    EXPECT(s.packed());
Paul's avatar
Paul committed
46
    EXPECT(s.type() == migraph::shape::float_type);
Paul's avatar
Paul committed
47
48
49
50
51
52
53
54
    EXPECT(s.lens()[0] == 100);
    EXPECT(s.lens()[1] == 32);
    EXPECT(s.lens()[2] == 8);
    EXPECT(s.lens()[3] == 8);
    EXPECT(s.strides()[0] == s.lens()[1] * s.strides()[1]);
    EXPECT(s.strides()[1] == s.lens()[2] * s.strides()[2]);
    EXPECT(s.strides()[2] == s.lens()[3] * s.strides()[3]);
    EXPECT(s.strides()[3] == 1);
Paul's avatar
Paul committed
55
56
    EXPECT(s.elements() == 100 * 32 * 8 * 8);
    EXPECT(s.bytes() == 100 * 32 * 8 * 8 * sizeof(float));
Paul's avatar
Paul committed
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
    EXPECT(s.index({0, 0, 0, 0}) == 0);
    EXPECT(s.index({0, 0, 0, 1}) == 1);
    EXPECT(s.index({0, 0, 0, 0}) == s.index(0));
    EXPECT(s.index({0, 0, 0, 1}) == s.index(1));
    EXPECT(s.index({0, 0, 1, 0}) == s.index(8));
    EXPECT(s.index({0, 1, 0, 0}) == s.index(8 * 8));
    EXPECT(s.index({1, 0, 0, 0}) == s.index(8 * 8 * 32));
    EXPECT(s.index(0) == 0);
    EXPECT(s.index(1) == 1);
    EXPECT(s.index(8) == 8);
    EXPECT(s.index(8 * 8) == 8 * 8);
    EXPECT(s.index(8 * 8 * 32) == 8 * 8 * 32);
    EXPECT(s.index(s.elements() - 1) == s.elements() - 1);
}

void test_shape4_nonpacked()
{
    std::vector<std::size_t> lens       = {100, 32, 8, 8};
Paul's avatar
Paul committed
75
76
    std::array<std::size_t, 4> offsets  = {{5, 10, 0, 6}};
    std::array<std::size_t, 4> adj_lens = {{0, 0, 0, 0}};
Paul's avatar
Paul committed
77
78
79
80
81
82

    std::transform(
        lens.begin(), lens.end(), offsets.begin(), adj_lens.begin(), std::plus<size_t>());
    // adj_lens should be: { 105, 42, 8, 14 }
    std::vector<std::size_t> strides(4);
    strides.back() = 1;
Paul's avatar
Paul committed
83
84
85
86
    std::partial_sum(adj_lens.rbegin(),
                     adj_lens.rend() - 1,
                     strides.rbegin() + 1,
                     std::multiplies<std::size_t>());
Paul's avatar
Paul committed
87

Paul's avatar
Paul committed
88
    migraph::shape s{migraph::shape::float_type, lens, strides};
Paul's avatar
Paul committed
89
    EXPECT(!s.packed());
Paul's avatar
Paul committed
90
    EXPECT(s.type() == migraph::shape::float_type);
Paul's avatar
Paul committed
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
    EXPECT(s.lens()[0] == 100);
    EXPECT(s.lens()[1] == 32);
    EXPECT(s.lens()[2] == 8);
    EXPECT(s.lens()[3] == 8);
    EXPECT(s.strides()[0] == 4704);
    EXPECT(s.strides()[1] == 112);
    EXPECT(s.strides()[2] == 14);
    EXPECT(s.strides()[3] == 1);
    EXPECT(s.elements() == 100 * 32 * 8 * 8);
    EXPECT(s.bytes() == sizeof(float) * 469274);

    EXPECT(s.index(0) == 0);
    EXPECT(s.index(1) == 1);
    EXPECT(s.index({0, 0, 0, 0}) == 0);
    EXPECT(s.index({0, 0, 0, 1}) == s.index(1));
    // TODO: Fix these tests
    // EXPECT(s.index({0, 0, 1, 0}) == s.index(8));
    // EXPECT(s.index({0, 1, 0, 0}) == s.index(8 * 8));
    // EXPECT(s.index({1, 0, 0, 0}) == s.index(8 * 8 * 32));
    // EXPECT(s.index(s.elements() - 1) == 469273);
Paul's avatar
Paul committed
111
112
}

Paul's avatar
Paul committed
113
114
int main()
{
Paul's avatar
Paul committed
115
    test_shape_assign();
Paul's avatar
Paul committed
116
117
118
    test_shape_packed_default();
    test_shape_packed();
    test_shape_transposed();
Paul's avatar
Paul committed
119
    test_shape_default();
Paul's avatar
Paul committed
120
    test_shape4();
Paul's avatar
Paul committed
121
    test_shape4_nonpacked();
Paul's avatar
Paul committed
122
}