simplify_reshapes_test.cpp 17.4 KB
Newer Older
Paul's avatar
Paul committed
1
2
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
3
#include <migraphx/pass_manager.hpp>
Paul's avatar
Paul committed
4
#include <migraphx/operators.hpp>
Paul's avatar
Paul committed
5
#include <migraphx/instruction.hpp>
Paul Fultz II's avatar
Paul Fultz II committed
6
#include <migraphx/generate.hpp>
Paul's avatar
Paul committed
7
8
9
#include <basic_ops.hpp>
#include <test.hpp>

10
// Applies simplify_reshapes and then dead_code_elimination to `p`, in place.
// DCE is scheduled second, presumably so that instructions left unused by the
// reshape rewrites are dropped before the tests in this file count what remains.
void run_pass(migraphx::program& p)
{
    migraphx::run_passes(p,
                         {
                             migraphx::simplify_reshapes{},
                             migraphx::dead_code_elimination{},
                         });
}
Paul's avatar
Paul committed
14

Paul's avatar
Paul committed
15
// transpose -> contiguous -> contiguous on a 2x2 literal. The second
// contiguous is redundant; after the pass 4 instructions remain (presumably
// literal, transpose, a single contiguous, pass) and the output stays standard.
TEST_CASE(double_contig)
{
    migraphx::program p;

    auto input    = p.add_literal(get_2x2());
    auto swapped  = p.add_instruction(migraphx::op::transpose{{1, 0}}, input);
    auto packed   = p.add_instruction(migraphx::op::contiguous{}, swapped);
    auto repacked = p.add_instruction(migraphx::op::contiguous{}, packed);
    p.add_instruction(pass_op{}, repacked);

    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());

    run_pass(p);

    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());
    EXPECT(std::distance(p.begin(), p.end()) == 4);

    // The data was transposed once, so it must not match the original literal.
    auto result = p.eval({}).back();
    EXPECT(result != get_2x2());
}

Paul's avatar
Paul committed
33
// Two {1, 0} transposes on a 2x2 literal cancel out. After the pass only two
// instructions remain and evaluation reproduces the literal.
TEST_CASE(double_transpose)
{
    migraphx::program p;

    auto input   = p.add_literal(get_2x2());
    auto flipped = p.add_instruction(migraphx::op::transpose{{1, 0}}, input);
    auto back    = p.add_instruction(migraphx::op::transpose{{1, 0}}, flipped);
    p.add_instruction(pass_op{}, back);

    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());

    run_pass(p);

    // Layout is unaffected by the rewrite...
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());
    // ...while the transpose pair disappears.
    EXPECT(std::distance(p.begin(), p.end()) == 2);

    auto result = p.eval({}).back();
    EXPECT(result == get_2x2());
}

Paul's avatar
Paul committed
50
// transpose -> contiguous -> transpose -> contiguous on a 2x2 literal is a
// round trip. After the pass two instructions remain and the data is unchanged.
TEST_CASE(double_transpose_contig)
{
    migraphx::program p;

    auto input  = p.add_literal(get_2x2());
    auto flip1  = p.add_instruction(migraphx::op::transpose{{1, 0}}, input);
    auto dense1 = p.add_instruction(migraphx::op::contiguous{}, flip1);
    auto flip2  = p.add_instruction(migraphx::op::transpose{{1, 0}}, dense1);
    auto dense2 = p.add_instruction(migraphx::op::contiguous{}, flip2);
    p.add_instruction(pass_op{}, dense2);

    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());

    run_pass(p);

    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());
    EXPECT(std::distance(p.begin(), p.end()) == 2);

    auto result = p.eval({}).back();
    EXPECT(result == get_2x2());
}

Paul's avatar
Paul committed
69
// A lone transpose feeding pass_op has nothing to simplify. The output stays
// transposed (non-standard) and three instructions remain.
TEST_CASE(single_transpose)
{
    migraphx::program p;

    auto input   = p.add_literal(get_2x2());
    auto flipped = p.add_instruction(migraphx::op::transpose{{1, 0}}, input);
    p.add_instruction(pass_op{}, flipped);

    EXPECT(not p.get_output_shapes().back().standard());
    EXPECT(p.get_output_shapes().back().transposed());

    run_pass(p);

    EXPECT(not p.get_output_shapes().back().standard());
    EXPECT(p.get_output_shapes().back().transposed());
    EXPECT(std::distance(p.begin(), p.end()) == 3);

    auto result = p.eval({}).back();
    EXPECT(result != get_2x2());
}

Paul's avatar
Paul committed
85
// Same {1, 0} round trip as double_transpose, but the second transpose is the
// program's last instruction (no pass_op consumer is added).
TEST_CASE(double_transpose_sin_pass)
{
    migraphx::program p;

    auto input   = p.add_literal(get_2x2());
    auto flipped = p.add_instruction(migraphx::op::transpose{{1, 0}}, input);
    p.add_instruction(migraphx::op::transpose{{1, 0}}, flipped);

    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());

    run_pass(p);

    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());

    // TODO: Fix this. The instruction-count check is disabled; presumably the
    // pass does not yet reduce a transpose pair that ends the program down to
    // a single instruction.
    // EXPECT(std::distance(p.begin(), p.end()) == 1);

    auto result = p.eval({}).back();
    EXPECT(result == get_2x2());
}

Paul's avatar
Paul committed
102
// A lone transpose that is itself the program's last instruction. It is left
// in place, the output stays transposed, and two instructions remain.
TEST_CASE(single_transpose_sin_pass)
{
    migraphx::program p;

    auto input = p.add_literal(get_2x2());
    p.add_instruction(migraphx::op::transpose{{1, 0}}, input);

    EXPECT(not p.get_output_shapes().back().standard());
    EXPECT(p.get_output_shapes().back().transposed());

    run_pass(p);

    EXPECT(not p.get_output_shapes().back().standard());
    EXPECT(p.get_output_shapes().back().transposed());
    EXPECT(std::distance(p.begin(), p.end()) == 2);

    auto result = p.eval({}).back();
    EXPECT(result != get_2x2());
}

Paul's avatar
Paul committed
117
118
119
// reshape -> transpose -> contiguous -> reshape, which maps {1, 112, 56, 56}
// back to the same dims. The transpose swaps axes 1 and 2 of the 5-D view, so
// nothing should be removed: the shape and the instruction count are both kept.
TEST_CASE(reshape_transpose)
{
    migraphx::program p;
    auto s = migraphx::shape{migraphx::shape::float_type, {1, 112, 56, 56}};

    auto input   = p.add_parameter("x", s);
    auto split   = p.add_instruction(migraphx::op::reshape{{1, 4, 28, 56, 56}}, input);
    auto swapped = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3, 4}}, split);
    auto dense   = p.add_instruction(migraphx::op::contiguous{}, swapped);
    auto merged  = p.add_instruction(migraphx::op::reshape{{1, 112, 56, 56}}, dense);
    p.add_instruction(pass_op{}, merged);

    EXPECT(p.get_output_shapes().back() == s);
    auto n = std::distance(p.begin(), p.end());

    run_pass(p);

    EXPECT(p.get_output_shapes().back() == s);
    EXPECT(std::distance(p.begin(), p.end()) == n);
}

Paul's avatar
Paul committed
134
135
136
137
138
139
140
141
// A transpose followed by one contiguous: nothing should be removed. The
// output shape and the instruction count are both preserved.
TEST_CASE(transpose_contiguous)
{
    migraphx::program p;
    auto s = migraphx::shape{migraphx::shape::float_type, {4, 4}};

    auto input = p.add_parameter("x", s);
    auto flip  = p.add_instruction(migraphx::op::transpose{{1, 0}}, input);
    auto dense = p.add_instruction(migraphx::op::contiguous{}, flip);
    p.add_instruction(pass_op{}, dense);

    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(p.begin(), p.end());

    run_pass(p);

    EXPECT(p.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(p.begin(), p.end()) == n);
}

// transpose -> contiguous -> contiguous: exactly one instruction (presumably
// the redundant contiguous) is removed, and the transpose itself survives.
TEST_CASE(transpose_double_contiguous)
{
    migraphx::program p;
    auto s = migraphx::shape{migraphx::shape::float_type, {4, 4}};

    auto input = p.add_parameter("x", s);
    auto t     = p.add_instruction(migraphx::op::transpose{{1, 0}}, input);
    auto once  = p.add_instruction(migraphx::op::contiguous{}, t);
    auto twice = p.add_instruction(migraphx::op::contiguous{}, once);
    p.add_instruction(pass_op{}, twice);

    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(p.begin(), p.end());

    run_pass(p);

    EXPECT(p.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(p.begin(), p.end()) == n - 1);
    EXPECT(p.has_instruction(t));
}

166
167
168
169
170
// Two chained, non-cancelling transposes on a {1, 2, 3} input. The pass drops
// one instruction (presumably by fusing them) without changing the output shape.
TEST_CASE(transpose_partial1)
{
    migraphx::program p;
    auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};

    auto input  = p.add_parameter("x", s);
    auto first  = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, input);
    auto second = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, first);
    p.add_instruction(pass_op{}, second);

    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(p.begin(), p.end());

    run_pass(p);

    EXPECT(p.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}

// Three chained transposes on a {1, 2, 3} input. Two instructions are removed
// (presumably leaving a single combined transpose) and the output shape is kept.
TEST_CASE(transpose_partial2)
{
    migraphx::program p;
    auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};

    auto input  = p.add_parameter("x", s);
    auto first  = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, input);
    auto second = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, first);
    auto third  = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, second);
    p.add_instruction(pass_op{}, third);

    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(p.begin(), p.end());

    run_pass(p);

    EXPECT(p.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(p.begin(), p.end()) == n - 2);
}

// Four chained transposes on a {1, 2, 3} input (the last two are identical
// {1, 0, 2} swaps). Three instructions are removed and the output shape is kept.
TEST_CASE(transpose_partial3)
{
    migraphx::program p;
    auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};

    // Build x -> T{1,0,2} -> T{1,2,0} -> T{1,0,2} -> T{1,0,2}, in that order.
    auto tail = p.add_parameter("x", s);
    for(const auto& perm : {migraphx::op::transpose{{1, 0, 2}},
                            migraphx::op::transpose{{1, 2, 0}},
                            migraphx::op::transpose{{1, 0, 2}},
                            migraphx::op::transpose{{1, 0, 2}}})
    {
        tail = p.add_instruction(perm, tail);
    }
    p.add_instruction(pass_op{}, tail);

    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(p.begin(), p.end());

    run_pass(p);

    EXPECT(p.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(p.begin(), p.end()) == n - 3);
}

Paul's avatar
Paul committed
214
215
216
// An identity transpose ({0, 1, 2}) is removed outright; the output shape is
// unchanged and one instruction fewer remains.
TEST_CASE(nop_transpose1)
{
    migraphx::program p;
    auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};

    auto input    = p.add_parameter("x", s);
    auto identity = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, input);
    p.add_instruction(pass_op{}, identity);

    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(p.begin(), p.end());

    run_pass(p);

    EXPECT(p.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}

// A chain of four identity transposes ({0, 1, 2}). All four are removed and
// the output shape is unchanged.
TEST_CASE(nop_transpose2)
{
    migraphx::program p;
    auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};

    auto tail = p.add_parameter("x", s);
    for(int i = 0; i < 4; ++i)
        tail = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, tail);
    p.add_instruction(pass_op{}, tail);

    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(p.begin(), p.end());

    run_pass(p);

    EXPECT(p.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(p.begin(), p.end()) == n - 4);
}

// concat -> identity transpose -> real transpose ({0, 1, 3, 2}). Exactly one
// instruction (presumably the identity transpose) is removed and the output
// shape is unchanged.
TEST_CASE(nop_transpose3)
{
    migraphx::program p;
    auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};

    auto lhs      = p.add_parameter("x", s);
    auto rhs      = p.add_parameter("y", s);
    auto joined   = p.add_instruction(migraphx::op::concat{3}, lhs, rhs);
    auto identity = p.add_instruction(migraphx::op::transpose{{0, 1, 2, 3}}, joined);
    auto swapped  = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, identity);
    p.add_instruction(pass_op{}, swapped);

    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(p.begin(), p.end());

    run_pass(p);

    EXPECT(p.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}

// Transposing both inputs with {0, 1, 3, 2}, concatenating on axis 2, and
// transposing back with {0, 1, 3, 2} is the same as concatenating the original
// inputs on axis 3. The pass should remove the three transposes and leave a
// concat on axis 3.
TEST_CASE(concat_transpose1)
{
    migraphx::program p;
    auto s      = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x      = p.add_parameter("x", s);
    auto y      = p.add_parameter("y", s);
    auto xt     = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
    auto yt     = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, y);
    auto concat = p.add_instruction(migraphx::op::concat{2}, xt, yt);
    auto t      = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, concat);
    p.add_instruction(pass_op{}, t);

    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(p.begin(), p.end());

    run_pass(p);

    // Only lens are compared, presumably because the rewritten graph may yield
    // different strides for the same logical output.
    EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
    EXPECT(std::distance(p.begin(), p.end()) == n - 3);

    // Bind each instruction by const reference; a by-value parameter copied
    // every instruction the search visited.
    auto new_concat = std::find_if(
        p.begin(), p.end(), [](const auto& ins) { return ins.name() == "concat"; });
    EXPECT(bool{new_concat != p.end()});
    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3);
}

Paul's avatar
Paul committed
284
285
286
287
288
289
290
291
292
293
294
// Both inputs are transposed with {0, 2, 3, 1} and concatenated on axis 3;
// axis 3 of the transposed view is axis 1 of the original inputs. The pass
// should remove two instructions and leave a concat on axis 1.
TEST_CASE(concat_transpose2)
{
    migraphx::program p;
    auto s      = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x      = p.add_parameter("x", s);
    auto y      = p.add_parameter("y", s);
    auto xt     = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
    auto yt     = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
    auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt);
    auto t      = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
    p.add_instruction(pass_op{}, t);

    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(p.begin(), p.end());

    run_pass(p);

    // Only lens are compared; strides of the rewritten output may differ.
    EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
    EXPECT(std::distance(p.begin(), p.end()) == n - 2);

    // Const-reference predicate avoids copying each visited instruction.
    auto new_concat = std::find_if(
        p.begin(), p.end(), [](const auto& ins) { return ins.name() == "concat"; });
    EXPECT(bool{new_concat != p.end()});
    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}

306
307
308
309
310
311
312
313
314
315
316
// Like concat_transpose2, but the inputs differ along axis 1 ({1, 2, 3, 4} vs
// {1, 5, 3, 4}), which is exactly the axis the concat moves to after the
// rewrite. The pass should remove two instructions and leave a concat on axis 1.
TEST_CASE(concat_transpose3)
{
    migraphx::program p;
    // Named input shapes. Previously `s` was declared here but never used,
    // while x repeated the same shape inline.
    auto xs     = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto ys     = migraphx::shape{migraphx::shape::float_type, {1, 5, 3, 4}};
    auto x      = p.add_parameter("x", xs);
    auto y      = p.add_parameter("y", ys);
    auto xt     = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
    auto yt     = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
    auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt);
    auto t      = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
    p.add_instruction(pass_op{}, t);

    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(p.begin(), p.end());

    run_pass(p);

    // Only lens are compared; strides of the rewritten output may differ.
    EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
    EXPECT(std::distance(p.begin(), p.end()) == n - 2);

    // Const-reference predicate avoids copying each visited instruction.
    auto new_concat = std::find_if(
        p.begin(), p.end(), [](const auto& ins) { return ins.name() == "concat"; });
    EXPECT(bool{new_concat != p.end()});
    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}

Paul Fultz II's avatar
Paul Fultz II committed
328
329
330
331
332
333
334
335
336
337
// concat(concat(x, y), concat(y, x)), all on axis 1. The pass should flatten
// this into a single concat, removing two instructions, while keeping the
// output lens.
TEST_CASE(nested_concat)
{
    migraphx::program p;
    auto s       = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x       = p.add_parameter("x", s);
    auto y       = p.add_parameter("y", s);
    auto concat1 = p.add_instruction(migraphx::op::concat{1}, x, y);
    auto concat2 = p.add_instruction(migraphx::op::concat{1}, y, x);
    auto concat3 = p.add_instruction(migraphx::op::concat{1}, concat1, concat2);
    p.add_instruction(pass_op{}, concat3);

    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(p.begin(), p.end());

    run_pass(p);

    EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
    EXPECT(std::distance(p.begin(), p.end()) == n - 2);
    // Const-reference predicate: no per-instruction copy while counting.
    EXPECT(std::count_if(p.begin(), p.end(), [](const auto& ins) {
               return ins.name() == "concat";
           }) == 1);
}

// concat(concat(x, y), concat(y, x), literal) on axis 1. One operand is not a
// concat, but the nested concats should still be flattened into a single
// concat, removing two instructions.
TEST_CASE(nested_concat_partial)
{
    migraphx::program p;
    auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x = p.add_parameter("x", s);
    auto y = p.add_parameter("y", s);
    auto l = p.add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1, 4, 3, 4}}));
    auto concat1 = p.add_instruction(migraphx::op::concat{1}, x, y);
    auto concat2 = p.add_instruction(migraphx::op::concat{1}, y, x);
    auto concat3 = p.add_instruction(migraphx::op::concat{1}, concat1, concat2, l);
    p.add_instruction(pass_op{}, concat3);

    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(p.begin(), p.end());

    run_pass(p);

    EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
    EXPECT(std::distance(p.begin(), p.end()) == n - 2);
    // Const-reference predicate: no per-instruction copy while counting.
    EXPECT(std::count_if(p.begin(), p.end(), [](const auto& ins) {
               return ins.name() == "concat";
           }) == 1);
}

366
367
368
369
370
371
372
373
374
375
376
377
378
// A multibroadcast whose target lens equal the input's own lens, presumably a
// no-op. The pass removes exactly one instruction.
TEST_CASE(multibroadcast_simplify)
{
    migraphx::program p;
    std::vector<size_t> dims{1, 2, 3, 4};

    auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
    auto bcast = p.add_instruction(migraphx::op::multibroadcast{dims}, input);
    p.add_instruction(migraphx::op::mul{}, bcast, bcast);

    auto n = std::distance(p.begin(), p.end());
    run_pass(p);
    EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}

379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
// slice [32, 256) followed by slice [32, 64) of that result, on the same axis,
// must fold into a single slice [64, 96) of the original input.
TEST_CASE(double_slice1)
{
    // Reference program: one slice covering the combined range.
    migraphx::program p2;
    {
        auto input = p2.add_parameter("x", {migraphx::shape::int32_type, {256}});
        auto fused = p2.add_instruction(migraphx::op::slice{{0}, {64}, {96}}, input);
        p2.add_instruction(pass_op{}, fused);
    }

    // Program under test: two chained slices, simplified in place.
    migraphx::program p1;
    {
        auto input = p1.add_parameter("x", {migraphx::shape::int32_type, {256}});
        auto outer = p1.add_instruction(migraphx::op::slice{{0}, {32}, {256}}, input);
        auto inner = p1.add_instruction(migraphx::op::slice{{0}, {32}, {64}}, outer);
        p1.add_instruction(pass_op{}, inner);
    }
    run_pass(p1);

    EXPECT(p1 == p2);
}

// slice [32, 128) followed by slice [0, 32) of that result, on the same axis,
// must fold into a single slice [32, 64) of the original input.
TEST_CASE(double_slice2)
{
    // Reference program: one slice covering the combined range.
    migraphx::program p2;
    {
        auto input = p2.add_parameter("x", {migraphx::shape::int32_type, {256}});
        auto fused = p2.add_instruction(migraphx::op::slice{{0}, {32}, {64}}, input);
        p2.add_instruction(pass_op{}, fused);
    }

    // Program under test: two chained slices, simplified in place.
    migraphx::program p1;
    {
        auto input = p1.add_parameter("x", {migraphx::shape::int32_type, {256}});
        auto outer = p1.add_instruction(migraphx::op::slice{{0}, {32}, {128}}, input);
        auto inner = p1.add_instruction(migraphx::op::slice{{0}, {0}, {32}}, outer);
        p1.add_instruction(pass_op{}, inner);
    }
    run_pass(p1);

    EXPECT(p1 == p2);
}

// A slice on axis 0 ([32, 128)) followed by a slice on axis 1 ([0, 32)) must
// merge into one slice over axes {0, 1}, starts {32, 0} and ends {128, 32}.
TEST_CASE(double_slice_multi_axes)
{
    // Reference program: one slice spanning both axes.
    migraphx::program p2;
    {
        auto input = p2.add_parameter("x", {migraphx::shape::int32_type, {256, 128}});
        auto fused = p2.add_instruction(migraphx::op::slice{{0, 1}, {32, 0}, {128, 32}}, input);
        p2.add_instruction(pass_op{}, fused);
    }

    // Program under test: per-axis slices, simplified in place.
    migraphx::program p1;
    {
        auto input = p1.add_parameter("x", {migraphx::shape::int32_type, {256, 128}});
        auto rows  = p1.add_instruction(migraphx::op::slice{{0}, {32}, {128}}, input);
        auto cols  = p1.add_instruction(migraphx::op::slice{{1}, {0}, {32}}, rows);
        p1.add_instruction(pass_op{}, cols);
    }
    run_pass(p1);

    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
439
// Entry point: forwards the command line to the test driver, which presumably
// selects and runs the TEST_CASEs registered in this file.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}