// simplify_reshapes_test.cpp: tests for the migraphx::simplify_reshapes pass.

#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <basic_ops.hpp>
#include <test.hpp>

#include <algorithm>
#include <iterator>
#include <vector>

10
void run_pass(migraphx::program& p)
Paul's avatar
Paul committed
11
{
12
13
    auto* mm = p.get_main_module();
    migraphx::run_passes(*mm, {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}});
14
}
Paul's avatar
Paul committed
15

Paul's avatar
Paul committed
16
TEST_CASE(double_contig)
Paul's avatar
Paul committed
17
{
Paul's avatar
Paul committed
18
    migraphx::program p;
19
20
21
22
23
24
25

    auto* mm = p.get_main_module();
    auto l   = mm->add_literal(get_2x2());
    auto t1  = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
    auto c1  = mm->add_instruction(migraphx::op::contiguous{}, t1);
    auto c2  = mm->add_instruction(migraphx::op::contiguous{}, c1);
    mm->add_return({c2});
26
27
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());
28
    run_pass(p);
29
30
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());
Paul's avatar
Paul committed
31
    EXPECT(std::distance(p.begin(), p.end()) == 4);
32
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
33
    EXPECT(result != get_2x2());
Paul's avatar
Paul committed
34
35
}

Paul's avatar
Paul committed
36
TEST_CASE(double_transpose)
Paul's avatar
Paul committed
37
{
Paul's avatar
Paul committed
38
    migraphx::program p;
39
40
41
42
43
44

    auto* mm = p.get_main_module();
    auto l   = mm->add_literal(get_2x2());
    auto t1  = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
    auto t2  = mm->add_instruction(migraphx::op::transpose{{1, 0}}, t1);
    mm->add_return({t2});
45
46
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());
47
    run_pass(p);
48
49
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());
Paul's avatar
Paul committed
50
    EXPECT(std::distance(p.begin(), p.end()) == 2);
51
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
52
53
54
    EXPECT(result == get_2x2());
}

Paul's avatar
Paul committed
55
TEST_CASE(double_transpose_contig)
Paul's avatar
Paul committed
56
{
Paul's avatar
Paul committed
57
    migraphx::program p;
58
59
60
61
62
63
64
65

    auto* mm = p.get_main_module();
    auto l   = mm->add_literal(get_2x2());
    auto t1  = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
    auto c1  = mm->add_instruction(migraphx::op::contiguous{}, t1);
    auto t2  = mm->add_instruction(migraphx::op::transpose{{1, 0}}, c1);
    auto c2  = mm->add_instruction(migraphx::op::contiguous{}, t2);
    mm->add_return({c2});
66
67
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());
68
    run_pass(p);
69
70
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());
Paul's avatar
Paul committed
71
    EXPECT(std::distance(p.begin(), p.end()) == 2);
72
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
73
74
75
    EXPECT(result == get_2x2());
}

Paul's avatar
Paul committed
76
TEST_CASE(single_transpose)
Paul's avatar
Paul committed
77
{
Paul's avatar
Paul committed
78
    migraphx::program p;
79
80
81
82
83

    auto* mm = p.get_main_module();
    auto l   = mm->add_literal(get_2x2());
    auto t1  = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
    mm->add_return({t1});
84
85
    EXPECT(not p.get_output_shapes().back().standard());
    EXPECT(p.get_output_shapes().back().transposed());
86
    run_pass(p);
87
88
    EXPECT(not p.get_output_shapes().back().standard());
    EXPECT(p.get_output_shapes().back().transposed());
Paul's avatar
Paul committed
89
    EXPECT(std::distance(p.begin(), p.end()) == 3);
90
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
91
92
93
    EXPECT(result != get_2x2());
}

Paul's avatar
Paul committed
94
TEST_CASE(double_transpose_sin_pass)
Paul's avatar
Paul committed
95
{
Paul's avatar
Paul committed
96
    migraphx::program p;
97
98
99
100
101

    auto* mm = p.get_main_module();
    auto l   = mm->add_literal(get_2x2());
    auto t1  = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
    mm->add_instruction(migraphx::op::transpose{{1, 0}}, t1);
102
103
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());
104
    run_pass(p);
105
106
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());
Paul's avatar
Paul committed
107
108
    // TODO: Fix this
    // EXPECT(std::distance(p.begin(), p.end()) == 1);
109
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
110
111
112
    EXPECT(result == get_2x2());
}

Paul's avatar
Paul committed
113
TEST_CASE(single_transpose_sin_pass)
Paul's avatar
Paul committed
114
{
Paul's avatar
Paul committed
115
    migraphx::program p;
116
117
118
119

    auto* mm = p.get_main_module();
    auto l   = mm->add_literal(get_2x2());
    mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
120
121
    EXPECT(not p.get_output_shapes().back().standard());
    EXPECT(p.get_output_shapes().back().transposed());
122
    run_pass(p);
123
124
    EXPECT(not p.get_output_shapes().back().standard());
    EXPECT(p.get_output_shapes().back().transposed());
Paul's avatar
Paul committed
125
    EXPECT(std::distance(p.begin(), p.end()) == 2);
126
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
127
128
129
    EXPECT(result != get_2x2());
}

Paul's avatar
Paul committed
130
131
132
TEST_CASE(reshape_transpose)
{
    migraphx::program p;
133
134
135
136
137
138
139
140
141

    auto* mm = p.get_main_module();
    auto s   = migraphx::shape{migraphx::shape::float_type, {1, 112, 56, 56}};
    auto x   = mm->add_parameter("x", s);
    auto r1  = mm->add_instruction(migraphx::op::reshape{{1, 4, 28, 56, 56}}, x);
    auto t   = mm->add_instruction(migraphx::op::transpose{{0, 2, 1, 3, 4}}, r1);
    auto ct  = mm->add_instruction(migraphx::op::contiguous{}, t);
    auto r2  = mm->add_instruction(migraphx::op::reshape{{1, 112, 56, 56}}, ct);
    mm->add_return({r2});
142
    EXPECT(p.get_output_shapes().back() == s);
Paul's avatar
Paul committed
143
    auto n = std::distance(p.begin(), p.end());
144
    run_pass(p);
145
    EXPECT(p.get_output_shapes().back() == s);
Paul's avatar
Paul committed
146
147
148
    EXPECT(std::distance(p.begin(), p.end()) == n);
}

Paul's avatar
Paul committed
149
150
151
TEST_CASE(transpose_contiguous)
{
    migraphx::program p;
152
153
154
155
156
157
158

    auto* mm = p.get_main_module();
    auto s   = migraphx::shape{migraphx::shape::float_type, {4, 4}};
    auto x   = mm->add_parameter("x", s);
    auto t   = mm->add_instruction(migraphx::op::transpose{{1, 0}}, x);
    auto c1  = mm->add_instruction(migraphx::op::contiguous{}, t);
    mm->add_return({c1});
159
    auto out_shape = p.get_output_shapes().back();
Paul's avatar
Paul committed
160
    auto n         = std::distance(p.begin(), p.end());
161
    run_pass(p);
162
    EXPECT(p.get_output_shapes().back() == out_shape);
Paul's avatar
Paul committed
163
164
165
166
167
168
    EXPECT(std::distance(p.begin(), p.end()) == n);
}

// transpose -> contiguous -> contiguous: exactly one instruction (a redundant
// contiguous) is removed, and the transpose itself must survive.
TEST_CASE(transpose_double_contiguous)
{
    migraphx::program p;

    auto* mm      = p.get_main_module();
    auto in_shape = migraphx::shape{migraphx::shape::float_type, {4, 4}};
    auto input    = mm->add_parameter("x", in_shape);
    auto t        = mm->add_instruction(migraphx::op::transpose{{1, 0}}, input);
    auto once     = mm->add_instruction(migraphx::op::contiguous{}, t);
    mm->add_return({mm->add_instruction(migraphx::op::contiguous{}, once)});

    const auto out_shape = p.get_output_shapes().back();
    const auto n         = std::distance(p.begin(), p.end());
    run_pass(p);
    EXPECT(p.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(p.begin(), p.end()) == n - 1);
    EXPECT(mm->has_instruction(t));
}

185
186
187
TEST_CASE(transpose_partial1)
{
    migraphx::program p;
188
189
190
191
192
193
194

    auto* mm = p.get_main_module();
    auto s   = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto x   = mm->add_parameter("x", s);
    auto t1  = mm->add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
    auto t2  = mm->add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
    mm->add_return({t2});
195
    auto out_shape = p.get_output_shapes().back();
196
    auto n         = std::distance(p.begin(), p.end());
197
    run_pass(p);
198
    EXPECT(p.get_output_shapes().back() == out_shape);
199
200
201
202
203
204
    EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}

// Three consecutive transposes should collapse so that two instructions disappear,
// keeping the output shape.
TEST_CASE(transpose_partial2)
{
    migraphx::program p;

    auto* main_mod = p.get_main_module();
    auto in_shape  = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto input     = main_mod->add_parameter("x", in_shape);
    auto first     = main_mod->add_instruction(migraphx::op::transpose{{1, 0, 2}}, input);
    auto second    = main_mod->add_instruction(migraphx::op::transpose{{1, 2, 0}}, first);
    main_mod->add_return({main_mod->add_instruction(migraphx::op::transpose{{1, 0, 2}}, second)});

    const auto out_shape = p.get_output_shapes().back();
    const auto n         = std::distance(p.begin(), p.end());
    run_pass(p);
    EXPECT(p.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(p.begin(), p.end()) == n - 2);
}

// Four consecutive transposes should collapse so that three instructions disappear,
// keeping the output shape.
TEST_CASE(transpose_partial3)
{
    migraphx::program p;

    auto* main_mod = p.get_main_module();
    auto in_shape  = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto input     = main_mod->add_parameter("x", in_shape);
    auto first     = main_mod->add_instruction(migraphx::op::transpose{{1, 0, 2}}, input);
    auto second    = main_mod->add_instruction(migraphx::op::transpose{{1, 2, 0}}, first);
    auto third     = main_mod->add_instruction(migraphx::op::transpose{{1, 0, 2}}, second);
    main_mod->add_return({main_mod->add_instruction(migraphx::op::transpose{{1, 0, 2}}, third)});

    const auto out_shape = p.get_output_shapes().back();
    const auto n         = std::distance(p.begin(), p.end());
    run_pass(p);
    EXPECT(p.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(p.begin(), p.end()) == n - 3);
}

Paul's avatar
Paul committed
239
240
241
TEST_CASE(nop_transpose1)
{
    migraphx::program p;
242
243
244
245
246
247

    auto* mm = p.get_main_module();
    auto s   = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto x   = mm->add_parameter("x", s);
    auto t   = mm->add_instruction(migraphx::op::transpose{{0, 1, 2}}, x);
    mm->add_return({t});
248
    auto out_shape = p.get_output_shapes().back();
Paul's avatar
Paul committed
249
    auto n         = std::distance(p.begin(), p.end());
250
    run_pass(p);
251
    EXPECT(p.get_output_shapes().back() == out_shape);
Paul's avatar
Paul committed
252
253
254
255
256
257
    EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}

// A chain of four identity transposes feeding pass_op: all four should be removed.
TEST_CASE(nop_transpose2)
{
    migraphx::program p;

    auto* main_mod = p.get_main_module();
    auto in_shape  = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto tail      = main_mod->add_parameter("x", in_shape);
    for(int i = 0; i < 4; ++i)
        tail = main_mod->add_instruction(migraphx::op::transpose{{0, 1, 2}}, tail);
    main_mod->add_instruction(pass_op{}, tail);

    const auto out_shape = p.get_output_shapes().back();
    const auto n         = std::distance(p.begin(), p.end());
    run_pass(p);
    EXPECT(p.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(p.begin(), p.end()) == n - 4);
}

// After a concat, an identity transpose followed by a real {0, 1, 3, 2} transpose:
// only one instruction (the identity transpose) should go away.
TEST_CASE(nop_transpose3)
{
    migraphx::program p;

    auto* main_mod = p.get_main_module();
    auto in_shape  = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto lhs       = main_mod->add_parameter("x", in_shape);
    auto rhs       = main_mod->add_parameter("y", in_shape);
    auto joined    = main_mod->add_instruction(migraphx::op::concat{3}, lhs, rhs);
    auto same      = main_mod->add_instruction(migraphx::op::transpose{{0, 1, 2, 3}}, joined);
    main_mod->add_return({main_mod->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, same)});

    const auto out_shape = p.get_output_shapes().back();
    const auto n         = std::distance(p.begin(), p.end());
    run_pass(p);
    EXPECT(p.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}

// Converting a float tensor to float_type changes nothing and should be removed.
TEST_CASE(nop_convert)
{
    migraphx::program p;

    auto* main_mod = p.get_main_module();
    auto in_shape  = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto input     = main_mod->add_parameter("x", in_shape);
    main_mod->add_return(
        {main_mod->add_instruction(migraphx::op::convert{migraphx::shape::float_type}, input)});

    const auto out_shape = p.get_output_shapes().back();
    const auto n         = std::distance(p.begin(), p.end());
    run_pass(p);
    EXPECT(p.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}

// transpose{0,1,3,2} on both inputs -> concat on axis 2 -> transpose{0,1,3,2}.
// The pass should fold the transposes away, removing three instructions and leaving
// a single concat on axis 3 that produces the same output lens.
TEST_CASE(concat_transpose1)
{
    migraphx::program p;

    auto* mm    = p.get_main_module();
    auto s      = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x      = mm->add_parameter("x", s);
    auto y      = mm->add_parameter("y", s);
    auto xt     = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
    auto yt     = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, y);
    auto concat = mm->add_instruction(migraphx::op::concat{2}, xt, yt);
    auto t      = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, concat);
    mm->add_return({t});

    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(p.begin(), p.end());
    run_pass(p);
    EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
    EXPECT(std::distance(p.begin(), p.end()) == n - 3);

    // The predicate only reads name(), so bind by const reference instead of
    // copying every instruction.
    auto new_concat =
        std::find_if(p.begin(), p.end(), [](const auto& ins) { return ins.name() == "concat"; });
    EXPECT(bool{new_concat != p.end()});
    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3);
}

Paul's avatar
Paul committed
333
334
335
TEST_CASE(concat_transpose2)
{
    migraphx::program p;
336
337

    auto* mm    = p.get_main_module();
Paul's avatar
Paul committed
338
    auto s      = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
339
340
341
342
343
344
345
    auto x      = mm->add_parameter("x", s);
    auto y      = mm->add_parameter("y", s);
    auto xt     = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
    auto yt     = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
    auto concat = mm->add_instruction(migraphx::op::concat{-1}, xt, yt);
    auto t      = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
    mm->add_return({t});
346
    auto out_shape = p.get_output_shapes().back();
Paul's avatar
Paul committed
347
    auto n         = std::distance(p.begin(), p.end());
348
    run_pass(p);
349
    EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
Paul's avatar
Paul committed
350
351
352
353
354
355
356
    EXPECT(std::distance(p.begin(), p.end()) == n - 2);
    auto new_concat =
        std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
    EXPECT(bool{new_concat != p.end()});
    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}

357
358
359
TEST_CASE(concat_transpose3)
{
    migraphx::program p;
360
361
362
363
364
365
366
367
368
369

    auto* mm = p.get_main_module();
    auto s   = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x   = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}});
    auto y   = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {1, 5, 3, 4}});
    auto xt  = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
    auto yt  = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
    auto concat = mm->add_instruction(migraphx::op::concat{3}, xt, yt);
    auto t      = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
    mm->add_return({t});
370
    auto out_shape = p.get_output_shapes().back();
371
    auto n         = std::distance(p.begin(), p.end());
372
    run_pass(p);
373
    EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
374
375
376
377
378
379
380
    EXPECT(std::distance(p.begin(), p.end()) == n - 2);
    auto new_concat =
        std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
    EXPECT(bool{new_concat != p.end()});
    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}

Shucai Xiao's avatar
Shucai Xiao committed
381
382
383
TEST_CASE(concat_transpose4)
{
    migraphx::program p;
384
    auto* mm    = p.get_main_module();
Shucai Xiao's avatar
Shucai Xiao committed
385
386
    auto sx     = migraphx::shape{migraphx::shape::float_type, {1, 1, 12, 64}};
    auto sy     = migraphx::shape{migraphx::shape::float_type, {1, 12, 1, 64}};
387
388
389
390
391
392
393
    auto x      = mm->add_parameter("x", sx);
    auto y      = mm->add_parameter("y", sy);
    auto xt     = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
    auto yt     = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, y);
    auto concat = mm->add_instruction(migraphx::op::concat{3}, xt, yt);
    auto t      = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
    mm->add_return({t});
Shucai Xiao's avatar
Shucai Xiao committed
394
395
396

    migraphx::program p1 = p;
    run_pass(p);
397

Shucai Xiao's avatar
Shucai Xiao committed
398
399
400
    EXPECT(p1 == p);
}

Paul Fultz II's avatar
Paul Fultz II committed
401
402
403
TEST_CASE(nested_concat)
{
    migraphx::program p;
404
405

    auto* mm     = p.get_main_module();
Paul Fultz II's avatar
Paul Fultz II committed
406
    auto s       = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
407
408
409
410
411
412
    auto x       = mm->add_parameter("x", s);
    auto y       = mm->add_parameter("y", s);
    auto concat1 = mm->add_instruction(migraphx::op::concat{1}, x, y);
    auto concat2 = mm->add_instruction(migraphx::op::concat{1}, y, x);
    auto concat3 = mm->add_instruction(migraphx::op::concat{1}, concat1, concat2);
    mm->add_return({concat3});
413
    auto out_shape = p.get_output_shapes().back();
Paul Fultz II's avatar
Paul Fultz II committed
414
    auto n         = std::distance(p.begin(), p.end());
415
    run_pass(p);
416
    EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
Paul Fultz II's avatar
Paul Fultz II committed
417
418
419
420
421
422
423
    EXPECT(std::distance(p.begin(), p.end()) == n - 2);
    EXPECT(std::count_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; }) == 1);
}

// Like nested_concat, but the outer concat also takes a literal as a third input.
// The inner concats should still be flattened into one concat.
TEST_CASE(nested_concat_partial)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto s   = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x   = mm->add_parameter("x", s);
    auto y   = mm->add_parameter("y", s);
    auto l   = mm->add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1, 4, 3, 4}}));
    auto concat1 = mm->add_instruction(migraphx::op::concat{1}, x, y);
    auto concat2 = mm->add_instruction(migraphx::op::concat{1}, y, x);
    auto concat3 = mm->add_instruction(migraphx::op::concat{1}, concat1, concat2, l);
    mm->add_return({concat3});

    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(p.begin(), p.end());
    run_pass(p);
    EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
    EXPECT(std::distance(p.begin(), p.end()) == n - 2);

    // Bind by const reference: the predicate only reads name().
    auto is_concat = [](const auto& ins) { return ins.name() == "concat"; };
    EXPECT(std::count_if(p.begin(), p.end(), is_concat) == 1);
}

443
444
445
TEST_CASE(multibroadcast_simplify)
{
    migraphx::program p;
446
447

    auto* mm = p.get_main_module();
448
449
    std::vector<size_t> s_lens{1, 2, 3, 4};
    auto s = migraphx::shape{migraphx::shape::float_type, s_lens};
450
451
452
    auto x = mm->add_parameter("x", s);
    auto y = mm->add_instruction(migraphx::op::multibroadcast{s_lens}, x);
    mm->add_instruction(migraphx::op::mul{}, y, y);
453
454
455
456
457
    auto n = std::distance(p.begin(), p.end());
    run_pass(p);
    EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}

458
459
460
TEST_CASE(double_slice1)
{
    migraphx::program p1;
461
    auto* mm1 = p1.get_main_module();
462
    {
463
464
465
466
        auto x      = mm1->add_parameter("x", {migraphx::shape::int32_type, {256}});
        auto slice1 = mm1->add_instruction(migraphx::op::slice{{0}, {32}, {256}}, x);
        auto slice2 = mm1->add_instruction(migraphx::op::slice{{0}, {32}, {64}}, slice1);
        mm1->add_return({slice2});
467
468
469
470
    }
    run_pass(p1);

    migraphx::program p2;
471
    auto* mm2 = p2.get_main_module();
472
    {
473
474
475
        auto x     = mm2->add_parameter("x", {migraphx::shape::int32_type, {256}});
        auto slice = mm2->add_instruction(migraphx::op::slice{{0}, {64}, {96}}, x);
        mm2->add_return({slice});
476
477
478
479
480
481
482
    }
    EXPECT(p1 == p2);
}

// slice [32, 128) followed by slice [0, 32) of that result equals slice [32, 64) of
// x. The pass should fuse them into that single slice.
TEST_CASE(double_slice2)
{
    migraphx::program p1;
    {
        auto* main_mod = p1.get_main_module();
        auto input     = main_mod->add_parameter("x", {migraphx::shape::int32_type, {256}});
        auto outer     = main_mod->add_instruction(migraphx::op::slice{{0}, {32}, {128}}, input);
        auto inner     = main_mod->add_instruction(migraphx::op::slice{{0}, {0}, {32}}, outer);
        main_mod->add_return({inner});
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* main_mod = p2.get_main_module();
        auto input     = main_mod->add_parameter("x", {migraphx::shape::int32_type, {256}});
        main_mod->add_return(
            {main_mod->add_instruction(migraphx::op::slice{{0}, {32}, {64}}, input)});
    }
    EXPECT(p1 == p2);
}

// A slice on axis 0 followed by a slice on axis 1 should fuse into one slice over
// both axes ({0, 1}, starts {32, 0}, ends {128, 32}).
TEST_CASE(double_slice_multi_axes)
{
    migraphx::program p1;
    {
        auto* main_mod = p1.get_main_module();
        auto input     = main_mod->add_parameter("x", {migraphx::shape::int32_type, {256, 128}});
        auto rows      = main_mod->add_instruction(migraphx::op::slice{{0}, {32}, {128}}, input);
        auto cols      = main_mod->add_instruction(migraphx::op::slice{{1}, {0}, {32}}, rows);
        main_mod->add_return({cols});
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* main_mod = p2.get_main_module();
        auto input     = main_mod->add_parameter("x", {migraphx::shape::int32_type, {256, 128}});
        main_mod->add_return({main_mod->add_instruction(
            migraphx::op::slice{{0, 1}, {32, 0}, {128, 32}}, input)});
    }
    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
525
int main(int argc, const char* argv[]) { test::run(argc, argv); }