simplify_reshapes_test.cpp 14.2 KB
Newer Older
Paul's avatar
Paul committed
1
2
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
3
#include <migraphx/pass_manager.hpp>
Paul's avatar
Paul committed
4
#include <migraphx/operators.hpp>
Paul's avatar
Paul committed
5
#include <migraphx/instruction.hpp>
Paul Fultz II's avatar
Paul Fultz II committed
6
#include <migraphx/generate.hpp>
Paul's avatar
Paul committed
7
8
9
#include <basic_ops.hpp>
#include <test.hpp>

10
// Runs the pass under test on `p`: simplify_reshapes first, then
// dead_code_elimination (presumably so instructions left unused by the rewrite
// are dropped before the tests count instructions).
void run_pass(migraphx::program& p)
{
    migraphx::run_passes(p,
                         {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}});
}
Paul's avatar
Paul committed
14

Paul's avatar
Paul committed
15
TEST_CASE(double_contig)
{
    // transpose -> contiguous -> contiguous feeding pass_op.
    // After the pass the program should hold 4 instructions, and the output
    // should still differ from the untransposed literal.
    migraphx::program p;
    auto lit      = p.add_literal(get_2x2());
    auto trans    = p.add_instruction(migraphx::op::transpose{{1, 0}}, lit);
    auto contig_a = p.add_instruction(migraphx::op::contiguous{}, trans);
    auto contig_b = p.add_instruction(migraphx::op::contiguous{}, contig_a);
    p.add_instruction(pass_op{}, contig_b);

    const auto before = p.get_shape();
    EXPECT(before.standard());
    EXPECT(not before.transposed());

    run_pass(p);

    const auto after = p.get_shape();
    EXPECT(after.standard());
    EXPECT(not after.transposed());
    EXPECT(std::distance(p.begin(), p.end()) == 4);

    auto evaluated = p.eval({});
    EXPECT(evaluated != get_2x2());
}

Paul's avatar
Paul committed
33
TEST_CASE(double_transpose)
{
    // Two {1, 0} transposes back to back. After the pass only 2 instructions
    // should remain, and evaluation should reproduce the original literal.
    migraphx::program p;
    auto lit   = p.add_literal(get_2x2());
    auto first = p.add_instruction(migraphx::op::transpose{{1, 0}}, lit);
    auto back  = p.add_instruction(migraphx::op::transpose{{1, 0}}, first);
    p.add_instruction(pass_op{}, back);

    const auto before = p.get_shape();
    EXPECT(before.standard());
    EXPECT(not before.transposed());

    run_pass(p);

    const auto after = p.get_shape();
    EXPECT(after.standard());
    EXPECT(not after.transposed());
    EXPECT(std::distance(p.begin(), p.end()) == 2);

    auto evaluated = p.eval({});
    EXPECT(evaluated == get_2x2());
}

Paul's avatar
Paul committed
50
TEST_CASE(double_transpose_contig)
{
    // Two {1, 0} transposes, each followed by a contiguous. The pass is
    // expected to reduce this to 2 instructions while preserving the data.
    migraphx::program p;
    auto lit     = p.add_literal(get_2x2());
    auto flip    = p.add_instruction(migraphx::op::transpose{{1, 0}}, lit);
    auto flip_c  = p.add_instruction(migraphx::op::contiguous{}, flip);
    auto unflip  = p.add_instruction(migraphx::op::transpose{{1, 0}}, flip_c);
    auto final_c = p.add_instruction(migraphx::op::contiguous{}, unflip);
    p.add_instruction(pass_op{}, final_c);

    const auto before = p.get_shape();
    EXPECT(before.standard());
    EXPECT(not before.transposed());

    run_pass(p);

    const auto after = p.get_shape();
    EXPECT(after.standard());
    EXPECT(not after.transposed());
    EXPECT(std::distance(p.begin(), p.end()) == 2);

    auto evaluated = p.eval({});
    EXPECT(evaluated == get_2x2());
}

Paul's avatar
Paul committed
69
TEST_CASE(single_transpose)
{
    // A lone {1, 0} transpose feeding pass_op. The pass should keep all 3
    // instructions, and the output stays transposed (non-standard).
    migraphx::program p;
    auto lit   = p.add_literal(get_2x2());
    auto trans = p.add_instruction(migraphx::op::transpose{{1, 0}}, lit);
    p.add_instruction(pass_op{}, trans);

    const auto before = p.get_shape();
    EXPECT(not before.standard());
    EXPECT(before.transposed());

    run_pass(p);

    const auto after = p.get_shape();
    EXPECT(not after.standard());
    EXPECT(after.transposed());
    EXPECT(std::distance(p.begin(), p.end()) == 3);

    auto evaluated = p.eval({});
    EXPECT(evaluated != get_2x2());
}

Paul's avatar
Paul committed
85
TEST_CASE(double_transpose_sin_pass)
{
    // Same as double_transpose, but the second transpose is the program's last
    // instruction (no pass_op consumer).
    migraphx::program p;
    auto lit  = p.add_literal(get_2x2());
    auto once = p.add_instruction(migraphx::op::transpose{{1, 0}}, lit);
    p.add_instruction(migraphx::op::transpose{{1, 0}}, once);

    const auto before = p.get_shape();
    EXPECT(before.standard());
    EXPECT(not before.transposed());

    run_pass(p);

    const auto after = p.get_shape();
    EXPECT(after.standard());
    EXPECT(not after.transposed());
    // TODO: Fix this -- re-enable once the pass reduces this program to a
    // single instruction.
    // EXPECT(std::distance(p.begin(), p.end()) == 1);

    auto evaluated = p.eval({});
    EXPECT(evaluated == get_2x2());
}

Paul's avatar
Paul committed
102
TEST_CASE(single_transpose_sin_pass)
{
    // A lone {1, 0} transpose as the program's last instruction. Both
    // instructions should survive, and the output stays transposed.
    migraphx::program p;
    auto lit = p.add_literal(get_2x2());
    p.add_instruction(migraphx::op::transpose{{1, 0}}, lit);

    const auto before = p.get_shape();
    EXPECT(not before.standard());
    EXPECT(before.transposed());

    run_pass(p);

    const auto after = p.get_shape();
    EXPECT(not after.standard());
    EXPECT(after.transposed());
    EXPECT(std::distance(p.begin(), p.end()) == 2);

    auto evaluated = p.eval({});
    EXPECT(evaluated != get_2x2());
}

Paul's avatar
Paul committed
117
118
119
TEST_CASE(reshape_transpose)
{
    // reshape -> transpose -> contiguous -> reshape back to the input shape.
    // The pass must leave both the output shape and the instruction count
    // unchanged.
    migraphx::program p;
    auto in_shape = migraphx::shape{migraphx::shape::float_type, {1, 112, 56, 56}};
    auto input    = p.add_parameter("x", in_shape);
    auto split    = p.add_instruction(migraphx::op::reshape{{1, 4, 28, 56, 56}}, input);
    auto swapped  = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3, 4}}, split);
    auto packed   = p.add_instruction(migraphx::op::contiguous{}, swapped);
    auto merged   = p.add_instruction(migraphx::op::reshape{{1, 112, 56, 56}}, packed);
    p.add_instruction(pass_op{}, merged);

    EXPECT(p.get_shape() == in_shape);
    const auto count_before = std::distance(p.begin(), p.end());

    run_pass(p);

    EXPECT(p.get_shape() == in_shape);
    EXPECT(std::distance(p.begin(), p.end()) == count_before);
}

Paul's avatar
Paul committed
134
135
136
137
138
139
140
141
142
TEST_CASE(transpose_contiguous)
{
    // A single transpose + contiguous is not redundant. The pass should change
    // neither the output shape nor the instruction count.
    migraphx::program p;
    auto shape_4x4 = migraphx::shape{migraphx::shape::float_type, {4, 4}};
    auto input     = p.add_parameter("x", shape_4x4);
    auto trans     = p.add_instruction(migraphx::op::transpose{{1, 0}}, input);
    auto packed    = p.add_instruction(migraphx::op::contiguous{}, trans);
    p.add_instruction(pass_op{}, packed);

    const auto expected_shape = p.get_shape();
    const auto count_before   = std::distance(p.begin(), p.end());

    run_pass(p);

    EXPECT(p.get_shape() == expected_shape);
    EXPECT(std::distance(p.begin(), p.end()) == count_before);
}

TEST_CASE(transpose_double_contiguous)
{
    // transpose followed by two contiguous ops. The pass should remove exactly
    // one instruction, keep the transpose, and preserve the output shape.
    migraphx::program p;
    auto shape_4x4 = migraphx::shape{migraphx::shape::float_type, {4, 4}};
    auto input     = p.add_parameter("x", shape_4x4);
    auto trans     = p.add_instruction(migraphx::op::transpose{{1, 0}}, input);
    auto packed_a  = p.add_instruction(migraphx::op::contiguous{}, trans);
    auto packed_b  = p.add_instruction(migraphx::op::contiguous{}, packed_a);
    p.add_instruction(pass_op{}, packed_b);

    const auto expected_shape = p.get_shape();
    const auto count_before   = std::distance(p.begin(), p.end());

    run_pass(p);

    EXPECT(p.get_shape() == expected_shape);
    EXPECT(std::distance(p.begin(), p.end()) == count_before - 1);
    EXPECT(p.has_instruction(trans));
}

166
167
168
169
170
TEST_CASE(transpose_partial1)
{
    // Two consecutive non-identity transposes. The pass is expected to drop
    // one instruction without changing the output shape.
    migraphx::program p;
    auto shape_123 = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto input     = p.add_parameter("x", shape_123);
    auto perm_a    = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, input);
    auto perm_b    = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, perm_a);
    p.add_instruction(pass_op{}, perm_b);

    const auto expected_shape = p.get_shape();
    const auto count_before   = std::distance(p.begin(), p.end());

    run_pass(p);

    EXPECT(p.get_shape() == expected_shape);
    EXPECT(std::distance(p.begin(), p.end()) == count_before - 1);
}

TEST_CASE(transpose_partial2)
{
    // Three consecutive transposes. The pass is expected to drop two
    // instructions without changing the output shape.
    migraphx::program p;
    auto shape_123 = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto input     = p.add_parameter("x", shape_123);
    auto perm_a    = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, input);
    auto perm_b    = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, perm_a);
    auto perm_c    = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, perm_b);
    p.add_instruction(pass_op{}, perm_c);

    const auto expected_shape = p.get_shape();
    const auto count_before   = std::distance(p.begin(), p.end());

    run_pass(p);

    EXPECT(p.get_shape() == expected_shape);
    EXPECT(std::distance(p.begin(), p.end()) == count_before - 2);
}

TEST_CASE(transpose_partial3)
{
    // Four consecutive transposes. The pass is expected to drop three
    // instructions without changing the output shape.
    migraphx::program p;
    auto shape_123 = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto input     = p.add_parameter("x", shape_123);
    auto perm_a    = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, input);
    auto perm_b    = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, perm_a);
    auto perm_c    = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, perm_b);
    auto perm_d    = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, perm_c);
    p.add_instruction(pass_op{}, perm_d);

    const auto expected_shape = p.get_shape();
    const auto count_before   = std::distance(p.begin(), p.end());

    run_pass(p);

    EXPECT(p.get_shape() == expected_shape);
    EXPECT(std::distance(p.begin(), p.end()) == count_before - 3);
}

Paul's avatar
Paul committed
214
215
216
TEST_CASE(nop_transpose1)
{
    // An identity-permutation transpose {0, 1, 2}. The pass is expected to
    // drop one instruction without changing the output shape.
    migraphx::program p;
    auto shape_123 = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto input     = p.add_parameter("x", shape_123);
    auto identity  = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, input);
    p.add_instruction(pass_op{}, identity);

    const auto expected_shape = p.get_shape();
    const auto count_before   = std::distance(p.begin(), p.end());

    run_pass(p);

    EXPECT(p.get_shape() == expected_shape);
    EXPECT(std::distance(p.begin(), p.end()) == count_before - 1);
}

TEST_CASE(nop_transpose2)
{
    // A chain of four identity transposes {0, 1, 2}. The pass is expected to
    // drop all four without changing the output shape.
    migraphx::program p;
    auto shape_123 = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto tail      = p.add_parameter("x", shape_123);
    for(int i = 0; i < 4; ++i)
        tail = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, tail);
    p.add_instruction(pass_op{}, tail);

    const auto expected_shape = p.get_shape();
    const auto count_before   = std::distance(p.begin(), p.end());

    run_pass(p);

    EXPECT(p.get_shape() == expected_shape);
    EXPECT(std::distance(p.begin(), p.end()) == count_before - 4);
}

TEST_CASE(nop_transpose3)
{
    // concat -> identity transpose -> real transpose {0, 1, 3, 2}.
    // The pass is expected to drop exactly one instruction and keep the
    // output shape.
    migraphx::program p;
    auto shape_1234 = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto lhs        = p.add_parameter("x", shape_1234);
    auto rhs        = p.add_parameter("y", shape_1234);
    auto joined     = p.add_instruction(migraphx::op::concat{3}, lhs, rhs);
    auto identity   = p.add_instruction(migraphx::op::transpose{{0, 1, 2, 3}}, joined);
    auto swapped    = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, identity);
    p.add_instruction(pass_op{}, swapped);

    const auto expected_shape = p.get_shape();
    const auto count_before   = std::distance(p.begin(), p.end());

    run_pass(p);

    EXPECT(p.get_shape() == expected_shape);
    EXPECT(std::distance(p.begin(), p.end()) == count_before - 1);
}

TEST_CASE(concat_transpose1)
{
    // Both concat inputs are transposed by {0, 1, 3, 2}, concatenated on
    // axis 2, and the result is transposed by {0, 1, 3, 2} again.
    // After the pass: same output lens, three fewer instructions, and the
    // remaining concat runs along axis 3.
    migraphx::program p;
    auto s      = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x      = p.add_parameter("x", s);
    auto y      = p.add_parameter("y", s);
    auto xt     = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
    auto yt     = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, y);
    auto concat = p.add_instruction(migraphx::op::concat{2}, xt, yt);
    auto t      = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, concat);
    p.add_instruction(pass_op{}, t);
    auto out_shape = p.get_shape();
    auto n         = std::distance(p.begin(), p.end());

    run_pass(p);

    EXPECT(p.get_shape().lens() == out_shape.lens());
    EXPECT(std::distance(p.begin(), p.end()) == n - 3);

    // Bind each instruction by const reference; a by-value lambda parameter
    // would copy every instruction the search visits.
    auto new_concat = std::find_if(
        p.begin(), p.end(), [](const auto& ins) { return ins.name() == "concat"; });
    EXPECT(bool{new_concat != p.end()});
    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3);
}

Paul's avatar
Paul committed
284
285
286
287
288
289
290
291
292
293
294
295
296
TEST_CASE(concat_transpose2)
{
    // Both concat inputs are transposed by {0, 2, 3, 1}, concatenated on
    // axis 3, and the result is transposed by {0, 2, 3, 1} again.
    // After the pass: same output lens, two fewer instructions, and the
    // remaining concat runs along axis 1.
    migraphx::program p;
    auto s      = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x      = p.add_parameter("x", s);
    auto y      = p.add_parameter("y", s);
    auto xt     = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
    auto yt     = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
    auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt);
    auto t      = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
    p.add_instruction(pass_op{}, t);
    auto out_shape = p.get_shape();
    auto n         = std::distance(p.begin(), p.end());

    run_pass(p);

    EXPECT(p.get_shape().lens() == out_shape.lens());
    EXPECT(std::distance(p.begin(), p.end()) == n - 2);

    // Bind each instruction by const reference; a by-value lambda parameter
    // would copy every instruction the search visits.
    auto new_concat = std::find_if(
        p.begin(), p.end(), [](const auto& ins) { return ins.name() == "concat"; });
    EXPECT(bool{new_concat != p.end()});
    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}

306
307
308
309
310
311
312
313
314
315
316
317
318
TEST_CASE(concat_transpose3)
{
    // Like concat_transpose2, but the two inputs differ along dimension 1
    // ({1, 2, 3, 4} vs {1, 5, 3, 4}). After the pass: same output lens, two
    // fewer instructions, and the remaining concat runs along axis 1.
    migraphx::program p;
    // Name both input shapes. Previously a shape `s` was declared but never
    // used, while the parameter shapes were spelled out inline.
    auto sx     = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto sy     = migraphx::shape{migraphx::shape::float_type, {1, 5, 3, 4}};
    auto x      = p.add_parameter("x", sx);
    auto y      = p.add_parameter("y", sy);
    auto xt     = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
    auto yt     = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
    auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt);
    auto t      = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
    p.add_instruction(pass_op{}, t);
    auto out_shape = p.get_shape();
    auto n         = std::distance(p.begin(), p.end());

    run_pass(p);

    EXPECT(p.get_shape().lens() == out_shape.lens());
    EXPECT(std::distance(p.begin(), p.end()) == n - 2);

    // Bind each instruction by const reference; a by-value lambda parameter
    // would copy every instruction the search visits.
    auto new_concat = std::find_if(
        p.begin(), p.end(), [](const auto& ins) { return ins.name() == "concat"; });
    EXPECT(bool{new_concat != p.end()});
    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}

Paul Fultz II's avatar
Paul Fultz II committed
328
329
330
331
332
333
334
335
336
337
338
339
TEST_CASE(nested_concat)
{
    // concat(concat(x, y), concat(y, x)), all along axis 1.
    // After the pass: same output lens, two fewer instructions, and exactly
    // one concat left.
    migraphx::program p;
    auto s       = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x       = p.add_parameter("x", s);
    auto y       = p.add_parameter("y", s);
    auto concat1 = p.add_instruction(migraphx::op::concat{1}, x, y);
    auto concat2 = p.add_instruction(migraphx::op::concat{1}, y, x);
    auto concat3 = p.add_instruction(migraphx::op::concat{1}, concat1, concat2);
    p.add_instruction(pass_op{}, concat3);
    auto out_shape = p.get_shape();
    auto n         = std::distance(p.begin(), p.end());

    run_pass(p);

    EXPECT(p.get_shape().lens() == out_shape.lens());
    EXPECT(std::distance(p.begin(), p.end()) == n - 2);
    // Const-reference parameter avoids copying each instruction while counting.
    EXPECT(std::count_if(p.begin(), p.end(), [](const auto& ins) {
               return ins.name() == "concat";
           }) == 1);
}

TEST_CASE(nested_concat_partial)
{
    // concat(concat(x, y), concat(y, x), literal), all along axis 1.
    // After the pass: same output lens, two fewer instructions, and exactly
    // one concat left.
    migraphx::program p;
    auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x = p.add_parameter("x", s);
    auto y = p.add_parameter("y", s);
    auto l = p.add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1, 4, 3, 4}}));
    auto concat1 = p.add_instruction(migraphx::op::concat{1}, x, y);
    auto concat2 = p.add_instruction(migraphx::op::concat{1}, y, x);
    auto concat3 = p.add_instruction(migraphx::op::concat{1}, concat1, concat2, l);
    p.add_instruction(pass_op{}, concat3);
    auto out_shape = p.get_shape();
    auto n         = std::distance(p.begin(), p.end());

    run_pass(p);

    EXPECT(p.get_shape().lens() == out_shape.lens());
    EXPECT(std::distance(p.begin(), p.end()) == n - 2);
    // Const-reference parameter avoids copying each instruction while counting.
    EXPECT(std::count_if(p.begin(), p.end(), [](const auto& ins) {
               return ins.name() == "concat";
           }) == 1);
}

Paul's avatar
Paul committed
366
int main(int argc, const char* argv[])
{
    // Hand the command line to the test framework's runner, which executes
    // the TEST_CASEs defined in this file.
    test::run(argc, argv);
}