shape_test.cpp 6.94 KB
Newer Older
Paul's avatar
Paul committed
1

Paul's avatar
Paul committed
2
#include <migraphx/shape.hpp>
Paul's avatar
Paul committed
3
4
5
#include <algorithm>
#include <array>
#include <functional>
#include <numeric>
#include <vector>
Paul's avatar
Paul committed
6
7
#include "test.hpp"

Paul's avatar
Paul committed
8
TEST_CASE(test_shape_default)
{
    // A default-constructed shape holds no elements and occupies no bytes.
    const migraphx::shape empty_shape{};
    EXPECT(empty_shape.bytes() == 0);
    EXPECT(empty_shape.elements() == 0);
}

Paul's avatar
Paul committed
15
TEST_CASE(test_shape_assign)
{
    migraphx::shape s1{migraphx::shape::float_type, {100, 32, 8, 8}};

    // Copy-initialization: despite the `=`, this goes through the copy constructor.
    migraphx::shape s2 = s1; // NOLINT
    EXPECT(s1 == s2);
    EXPECT(!(s1 != s2));

    // Genuine copy-assignment onto an already-constructed (default, empty) shape.
    // The pre-check guarantees the assignment actually changes the target.
    migraphx::shape s3{};
    EXPECT(s3 != s1);
    s3 = s1;
    EXPECT(s3 == s1);
    EXPECT(!(s3 != s1));
}

Paul's avatar
Paul committed
23
TEST_CASE(test_shape_packed_default)
{
    // No strides are supplied, so the shape derives them from the lens; the result
    // must be classified as standard and packed, and neither transposed nor broadcast.
    const migraphx::shape square{migraphx::shape::float_type, {2, 2}};

    EXPECT(square.standard());
    EXPECT(square.packed());
    EXPECT(not square.transposed());
    EXPECT(not square.broadcasted());
}

Paul's avatar
Paul committed
32
TEST_CASE(test_shape_packed)
{
    // Strides {2, 1} are the row-major strides for lens {2, 2}, spelled out explicitly.
    const migraphx::shape row_major{migraphx::shape::float_type, {2, 2}, {2, 1}};

    EXPECT(row_major.standard());
    EXPECT(row_major.packed());
    EXPECT(not row_major.transposed());
    EXPECT(not row_major.broadcasted());
}

Paul's avatar
Paul committed
41
TEST_CASE(test_shape_transposed1)
{
    // Strides {1, 2} are the row-major strides {2, 1} in reversed order: the data is
    // still packed, but the dimension order is permuted, so the layout is transposed.
    const migraphx::shape column_major{migraphx::shape::float_type, {2, 2}, {1, 2}};

    EXPECT(column_major.transposed());
    EXPECT(column_major.packed());
    EXPECT(not column_major.standard());
    EXPECT(not column_major.broadcasted());
}

Paul's avatar
Paul committed
50
51
52
53
54
55
56
57
58
TEST_CASE(test_shape_transposed2)
{
    // {2, 2, 2, 2, 1} are exactly the row-major strides for these lens. The equal
    // strides on the unit-length leading dimensions must not be mistaken for a transpose.
    const migraphx::shape unit_dims{
        migraphx::shape::float_type, {1, 1, 1, 1, 2}, {2, 2, 2, 2, 1}};

    EXPECT(not unit_dims.transposed());
    EXPECT(not unit_dims.broadcasted());
    EXPECT(unit_dims.standard());
    EXPECT(unit_dims.packed());
}

Paul's avatar
Paul committed
59
TEST_CASE(test_shape_broadcasted)
{
    // A zero stride on dimension 1 means stepping along that dimension never moves
    // through memory, i.e. the data is broadcast rather than stored densely.
    const migraphx::shape bcast{migraphx::shape::float_type, {2, 2}, {1, 0}};

    EXPECT(bcast.broadcasted());
    EXPECT(not bcast.standard());
    EXPECT(not bcast.packed());
    EXPECT(not bcast.transposed());
}

Paul's avatar
Paul committed
68
TEST_CASE(test_shape_default_copy)
{
    // Two independently default-constructed shapes compare equal ...
    migraphx::shape s1{};
    migraphx::shape s2{};
    EXPECT(s1 == s2);
    EXPECT(!(s1 != s2));

    // ... and an actual copy of a default-constructed shape is equal to its source
    // and, like it, describes no elements.
    const migraphx::shape s3 = s1; // NOLINT
    EXPECT(s3 == s1);
    EXPECT(!(s3 != s1));
    EXPECT(s3.elements() == 0);
}

Paul's avatar
Paul committed
76
TEST_CASE(test_shape4)
Paul's avatar
Paul committed
77
{
Paul's avatar
Paul committed
78
    migraphx::shape s{migraphx::shape::float_type, {100, 32, 8, 8}};
Paul's avatar
Paul committed
79
    EXPECT(s.standard());
Paul's avatar
Paul committed
80
    EXPECT(s.packed());
Paul's avatar
Paul committed
81
82
    EXPECT(not s.transposed());
    EXPECT(not s.broadcasted());
Paul's avatar
Paul committed
83
    EXPECT(s.type() == migraphx::shape::float_type);
Paul's avatar
Paul committed
84
85
86
87
88
89
90
91
    EXPECT(s.lens()[0] == 100);
    EXPECT(s.lens()[1] == 32);
    EXPECT(s.lens()[2] == 8);
    EXPECT(s.lens()[3] == 8);
    EXPECT(s.strides()[0] == s.lens()[1] * s.strides()[1]);
    EXPECT(s.strides()[1] == s.lens()[2] * s.strides()[2]);
    EXPECT(s.strides()[2] == s.lens()[3] * s.strides()[3]);
    EXPECT(s.strides()[3] == 1);
Paul's avatar
Paul committed
92
93
    EXPECT(s.elements() == 100 * 32 * 8 * 8);
    EXPECT(s.bytes() == 100 * 32 * 8 * 8 * sizeof(float));
Paul's avatar
Paul committed
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
    EXPECT(s.index({0, 0, 0, 0}) == 0);
    EXPECT(s.index({0, 0, 0, 1}) == 1);
    EXPECT(s.index({0, 0, 0, 0}) == s.index(0));
    EXPECT(s.index({0, 0, 0, 1}) == s.index(1));
    EXPECT(s.index({0, 0, 1, 0}) == s.index(8));
    EXPECT(s.index({0, 1, 0, 0}) == s.index(8 * 8));
    EXPECT(s.index({1, 0, 0, 0}) == s.index(8 * 8 * 32));
    EXPECT(s.index(0) == 0);
    EXPECT(s.index(1) == 1);
    EXPECT(s.index(8) == 8);
    EXPECT(s.index(8 * 8) == 8 * 8);
    EXPECT(s.index(8 * 8 * 32) == 8 * 8 * 32);
    EXPECT(s.index(s.elements() - 1) == s.elements() - 1);
}

Paul's avatar
Paul committed
109
TEST_CASE(test_shape42)
Paul's avatar
Paul committed
110
{
Paul's avatar
Paul committed
111
    migraphx::shape s{migraphx::shape::float_type, {100, 32, 8, 8}, {2048, 64, 8, 1}};
Paul's avatar
Paul committed
112
113
114
115
    EXPECT(s.standard());
    EXPECT(s.packed());
    EXPECT(not s.transposed());
    EXPECT(not s.broadcasted());
Paul's avatar
Paul committed
116
    EXPECT(s.type() == migraphx::shape::float_type);
Paul's avatar
Paul committed
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
    EXPECT(s.lens()[0] == 100);
    EXPECT(s.lens()[1] == 32);
    EXPECT(s.lens()[2] == 8);
    EXPECT(s.lens()[3] == 8);
    EXPECT(s.strides()[0] == s.lens()[1] * s.strides()[1]);
    EXPECT(s.strides()[1] == s.lens()[2] * s.strides()[2]);
    EXPECT(s.strides()[2] == s.lens()[3] * s.strides()[3]);
    EXPECT(s.strides()[3] == 1);
    EXPECT(s.elements() == 100 * 32 * 8 * 8);
    EXPECT(s.bytes() == 100 * 32 * 8 * 8 * sizeof(float));
    EXPECT(s.index({0, 0, 0, 0}) == 0);
    EXPECT(s.index({0, 0, 0, 1}) == 1);
    EXPECT(s.index({0, 0, 0, 0}) == s.index(0));
    EXPECT(s.index({0, 0, 0, 1}) == s.index(1));
    EXPECT(s.index({0, 0, 1, 0}) == s.index(8));
    EXPECT(s.index({0, 1, 0, 0}) == s.index(8 * 8));
    EXPECT(s.index({1, 0, 0, 0}) == s.index(8 * 8 * 32));
    EXPECT(s.index(0) == 0);
    EXPECT(s.index(1) == 1);
    EXPECT(s.index(8) == 8);
    EXPECT(s.index(8 * 8) == 8 * 8);
    EXPECT(s.index(8 * 8 * 32) == 8 * 8 * 32);
    EXPECT(s.index(s.elements() - 1) == s.elements() - 1);
}

Paul's avatar
Paul committed
142
TEST_CASE(test_shape4_transposed)
{
    // Lens {32, 100, 8, 8} with strides {64, 2048, 8, 1}: the standard strides of a
    // {100, 32, 8, 8} tensor with its first two dimensions swapped.
    const migraphx::shape t{migraphx::shape::float_type, {32, 100, 8, 8}, {64, 2048, 8, 1}};

    // Layout classification and element type.
    EXPECT(t.transposed());
    EXPECT(t.packed());
    EXPECT(not t.standard());
    EXPECT(not t.broadcasted());
    EXPECT(t.type() == migraphx::shape::float_type);

    // Lens and strides are stored exactly as given.
    const auto& lens    = t.lens();
    const auto& strides = t.strides();
    EXPECT(lens[0] == 32);
    EXPECT(lens[1] == 100);
    EXPECT(lens[2] == 8);
    EXPECT(lens[3] == 8);
    EXPECT(strides[0] == 64);
    EXPECT(strides[1] == 2048);
    EXPECT(strides[2] == 8);
    EXPECT(strides[3] == 1);

    // The permutation does not change the size.
    EXPECT(t.elements() == 100 * 32 * 8 * 8);
    EXPECT(t.bytes() == 100 * 32 * 8 * 8 * sizeof(float));

    // Multi-index lookups agree with the corresponding linear-index lookups.
    EXPECT(t.index({0, 0, 0, 0}) == 0);
    EXPECT(t.index({0, 0, 0, 1}) == 1);
    EXPECT(t.index({0, 0, 0, 0}) == t.index(0));
    EXPECT(t.index({0, 0, 0, 1}) == t.index(1));
    EXPECT(t.index({0, 0, 1, 0}) == t.index(8));
    EXPECT(t.index({0, 1, 0, 0}) == t.index(8 * 8));
    EXPECT(t.index({1, 0, 0, 0}) == t.index(8 * 8 * 100));

    // Linear indices are decomposed over the lens, then mapped through the swapped strides.
    EXPECT(t.index(0) == 0);
    EXPECT(t.index(1) == 1);
    EXPECT(t.index(8) == 8);
    EXPECT(t.index(8 * 8) == 2048);
    EXPECT(t.index(8 * 8 * 100) == 64);
    EXPECT(t.index(t.elements() - 1) == t.elements() - 1);
}

Paul's avatar
Paul committed
175
TEST_CASE(test_shape4_nonpacked)
{
    // Logical lens of the tensor, and per-dimension amounts added to them to form the
    // (larger) extents that the strides are computed from.
    const std::vector<std::size_t> lens      = {100, 32, 8, 8};
    const std::array<std::size_t, 4> offsets = {{5, 10, 0, 6}};
    std::array<std::size_t, 4> adj_lens      = {{0, 0, 0, 0}};

    std::transform(
        lens.begin(), lens.end(), offsets.begin(), adj_lens.begin(), std::plus<std::size_t>());
    // adj_lens should be: { 105, 42, 8, 14 }

    // Row-major strides over adj_lens: the innermost stride is 1 and each outer stride is
    // the running product of the adjusted lens to its right. The outermost adjusted len
    // (105) never contributes to a stride, hence the range stops at rend() - 1.
    std::vector<std::size_t> strides(4);
    strides.back() = 1;
    std::partial_sum(adj_lens.rbegin(),
                     adj_lens.rend() - 1,
                     strides.rbegin() + 1,
                     std::multiplies<std::size_t>());
    // strides should be: { 4704, 112, 14, 1 }  (4704 = 42 * 8 * 14, 112 = 8 * 14)

    migraphx::shape s{migraphx::shape::float_type, lens, strides};

    // The gaps introduced by the adjusted extents make the layout neither standard nor packed.
    EXPECT(not s.standard());
    EXPECT(not s.packed());
    EXPECT(not s.transposed());
    EXPECT(not s.broadcasted());
    EXPECT(s.type() == migraphx::shape::float_type);

    EXPECT(s.lens()[0] == 100);
    EXPECT(s.lens()[1] == 32);
    EXPECT(s.lens()[2] == 8);
    EXPECT(s.lens()[3] == 8);
    EXPECT(s.strides()[0] == 4704);
    EXPECT(s.strides()[1] == 112);
    EXPECT(s.strides()[2] == 14);
    EXPECT(s.strides()[3] == 1);

    // elements() counts only logical elements; bytes() spans one past the largest offset:
    // 469274 == 1 + 99 * 4704 + 31 * 112 + 7 * 14 + 7 * 1 (largest offset 469273, checked below).
    EXPECT(s.elements() == 100 * 32 * 8 * 8);
    EXPECT(s.bytes() == sizeof(float) * 469274);

    EXPECT(s.index(0) == 0);
    EXPECT(s.index(1) == 1);
    EXPECT(s.index({0, 0, 0, 0}) == 0);
    EXPECT(s.index({0, 0, 0, 1}) == s.index(1));
    EXPECT(s.index({0, 0, 1, 0}) == s.index(8));
    EXPECT(s.index({0, 1, 0, 0}) == s.index(8 * 8));
    EXPECT(s.index({1, 0, 0, 0}) == s.index(8 * 8 * 32));
    EXPECT(s.index(s.elements() - 1) == 469273);
}

Paul's avatar
Paul committed
218
int main(int argc, const char* argv[])
{
    // Delegate to the test framework's runner, forwarding the command-line arguments.
    test::run(argc, argv);
}