simplify_reshapes_test.cpp 21.4 KB
Newer Older
Paul's avatar
Paul committed
1
2
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
3
#include <migraphx/pass_manager.hpp>
Paul's avatar
Paul committed
4
#include <migraphx/operators.hpp>
Paul's avatar
Paul committed
5
#include <migraphx/instruction.hpp>
Paul Fultz II's avatar
Paul Fultz II committed
6
#include <migraphx/generate.hpp>
Paul's avatar
Paul committed
7
#include <basic_ops.hpp>
8
9
10
11
#include <migraphx/make_op.hpp>

#include <migraphx/serialize.hpp>

Paul's avatar
Paul committed
12
13
#include <test.hpp>

14
void run_pass(migraphx::program& p)
{
    // Run simplify_reshapes on the main module, followed by
    // dead_code_elimination (presumably to drop instructions the rewrite left
    // without users).
    auto& main_module = *p.get_main_module();
    migraphx::run_passes(main_module,
                         {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}});
}
Paul's avatar
Paul committed
19

Paul's avatar
Paul committed
20
TEST_CASE(double_contig)
{
    // A transposed literal is wrapped in two consecutive contiguous ops.
    // One of them is redundant and should be removed. The evaluated data
    // must still differ from the untransposed input.
    migraphx::program prog;
    auto* mod = prog.get_main_module();

    const auto swap_axes = migraphx::make_op("transpose", {{"dims", {1, 0}}});
    const auto contig    = migraphx::make_op("contiguous");

    auto input = mod->add_literal(get_2x2());
    auto trans = mod->add_instruction(swap_axes, input);
    auto inner = mod->add_instruction(contig, trans);
    auto outer = mod->add_instruction(contig, inner);
    mod->add_return({outer});

    auto shape_before = prog.get_output_shapes().back();
    EXPECT(shape_before.standard());
    EXPECT(not shape_before.transposed());

    run_pass(prog);

    auto shape_after = prog.get_output_shapes().back();
    EXPECT(shape_after.standard());
    EXPECT(not shape_after.transposed());
    // Five instructions were added; exactly one should be gone.
    EXPECT(std::distance(prog.begin(), prog.end()) == 4);
    auto result = prog.eval({}).back();
    EXPECT(result != get_2x2());
}

Paul's avatar
Paul committed
40
TEST_CASE(double_transpose)
{
    // transpose({1, 0}) applied twice is the identity. Both transposes
    // should be removed, and evaluation should reproduce the input literal.
    migraphx::program prog;
    auto* mod = prog.get_main_module();

    const auto swap_axes = migraphx::make_op("transpose", {{"dims", {1, 0}}});

    auto input  = mod->add_literal(get_2x2());
    auto first  = mod->add_instruction(swap_axes, input);
    auto second = mod->add_instruction(swap_axes, first);
    mod->add_return({second});

    auto shape_before = prog.get_output_shapes().back();
    EXPECT(shape_before.standard());
    EXPECT(not shape_before.transposed());

    run_pass(prog);

    auto shape_after = prog.get_output_shapes().back();
    EXPECT(shape_after.standard());
    EXPECT(not shape_after.transposed());
    // Only two instructions should remain.
    EXPECT(std::distance(prog.begin(), prog.end()) == 2);
    auto result = prog.eval({}).back();
    EXPECT(result == get_2x2());
}

Paul's avatar
Paul committed
59
TEST_CASE(double_transpose_contig)
{
    // Two cancelling {1, 0} transposes, each followed by a contiguous. The
    // whole chain should fold away, leaving the original data.
    migraphx::program prog;
    auto* mod = prog.get_main_module();

    const auto swap_axes = migraphx::make_op("transpose", {{"dims", {1, 0}}});
    const auto contig    = migraphx::make_op("contiguous");

    auto input    = mod->add_literal(get_2x2());
    auto first_t  = mod->add_instruction(swap_axes, input);
    auto first_c  = mod->add_instruction(contig, first_t);
    auto second_t = mod->add_instruction(swap_axes, first_c);
    auto second_c = mod->add_instruction(contig, second_t);
    mod->add_return({second_c});

    auto shape_before = prog.get_output_shapes().back();
    EXPECT(shape_before.standard());
    EXPECT(not shape_before.transposed());

    run_pass(prog);

    auto shape_after = prog.get_output_shapes().back();
    EXPECT(shape_after.standard());
    EXPECT(not shape_after.transposed());
    // Six instructions were added; only two should remain.
    EXPECT(std::distance(prog.begin(), prog.end()) == 2);
    auto result = prog.eval({}).back();
    EXPECT(result == get_2x2());
}

Paul's avatar
Paul committed
80
TEST_CASE(single_transpose)
{
    // A lone transpose is not redundant. The pass must keep it, so the
    // output stays transposed, non-standard and numerically different.
    migraphx::program prog;
    auto* mod = prog.get_main_module();

    auto input = mod->add_literal(get_2x2());
    auto trans = mod->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), input);
    mod->add_return({trans});

    auto shape_before = prog.get_output_shapes().back();
    EXPECT(not shape_before.standard());
    EXPECT(shape_before.transposed());

    run_pass(prog);

    auto shape_after = prog.get_output_shapes().back();
    EXPECT(not shape_after.standard());
    EXPECT(shape_after.transposed());
    // Nothing should have been removed.
    EXPECT(std::distance(prog.begin(), prog.end()) == 3);
    auto result = prog.eval({}).back();
    EXPECT(result != get_2x2());
}

Paul's avatar
Paul committed
98
TEST_CASE(double_transpose_sin_pass)
{
    // Same as double_transpose, but without an explicit return. The output
    // is presumably taken from the last instruction, the second transpose.
    migraphx::program prog;
    auto* mod = prog.get_main_module();

    const auto swap_axes = migraphx::make_op("transpose", {{"dims", {1, 0}}});

    auto input = mod->add_literal(get_2x2());
    auto first = mod->add_instruction(swap_axes, input);
    mod->add_instruction(swap_axes, first);

    auto shape_before = prog.get_output_shapes().back();
    EXPECT(shape_before.standard());
    EXPECT(not shape_before.transposed());

    run_pass(prog);

    auto shape_after = prog.get_output_shapes().back();
    EXPECT(shape_after.standard());
    EXPECT(not shape_after.transposed());
    // TODO: Fix this. The instruction-count expectation below is disabled;
    // presumably the pass cannot yet shrink a return-less program this far.
    // EXPECT(std::distance(prog.begin(), prog.end()) == 1);
    auto result = prog.eval({}).back();
    EXPECT(result == get_2x2());
}

Paul's avatar
Paul committed
117
TEST_CASE(single_transpose_sin_pass)
{
    // Same as single_transpose, but without an explicit return. The lone
    // transpose must survive the pass unchanged.
    migraphx::program prog;
    auto* mod = prog.get_main_module();

    auto input = mod->add_literal(get_2x2());
    mod->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), input);

    auto shape_before = prog.get_output_shapes().back();
    EXPECT(not shape_before.standard());
    EXPECT(shape_before.transposed());

    run_pass(prog);

    auto shape_after = prog.get_output_shapes().back();
    EXPECT(not shape_after.standard());
    EXPECT(shape_after.transposed());
    // Both instructions remain.
    EXPECT(std::distance(prog.begin(), prog.end()) == 2);
    auto result = prog.eval({}).back();
    EXPECT(result != get_2x2());
}

Paul's avatar
Paul committed
134
135
136
TEST_CASE(reshape_transpose)
{
    // The 112 channels are split into 4 x 28, those two axes are swapped,
    // the result is materialised and then merged back to 112. The output
    // shape equals the input shape, but the element order differs, so the
    // pass must not remove anything.
    migraphx::program prog;
    auto* mod           = prog.get_main_module();
    const auto in_shape = migraphx::shape{migraphx::shape::float_type, {1, 112, 56, 56}};

    auto input = mod->add_parameter("x", in_shape);
    auto split =
        mod->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 4, 28, 56, 56}}}), input);
    auto swapped =
        mod->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 1, 3, 4}}}), split);
    auto packed = mod->add_instruction(migraphx::make_op("contiguous"), swapped);
    auto merged =
        mod->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 112, 56, 56}}}), packed);
    mod->add_return({merged});
    EXPECT(prog.get_output_shapes().back() == in_shape);

    const auto count_before = std::distance(prog.begin(), prog.end());
    run_pass(prog);
    EXPECT(prog.get_output_shapes().back() == in_shape);
    EXPECT(std::distance(prog.begin(), prog.end()) == count_before);
}

Paul's avatar
Paul committed
153
154
155
TEST_CASE(transpose_contiguous)
{
    // A single transpose materialised by one contiguous. The pass should
    // keep every instruction and the output shape.
    migraphx::program prog;
    auto* mod           = prog.get_main_module();
    const auto in_shape = migraphx::shape{migraphx::shape::float_type, {4, 4}};

    auto input  = mod->add_parameter("x", in_shape);
    auto trans  = mod->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), input);
    auto packed = mod->add_instruction(migraphx::make_op("contiguous"), trans);
    mod->add_return({packed});

    const auto expected_shape = prog.get_output_shapes().back();
    const auto count_before   = std::distance(prog.begin(), prog.end());
    run_pass(prog);
    EXPECT(prog.get_output_shapes().back() == expected_shape);
    EXPECT(std::distance(prog.begin(), prog.end()) == count_before);
}

TEST_CASE(transpose_double_contiguous)
{
    // A transpose followed by two contiguous ops. One contiguous is redundant
    // and should go away, but the transpose itself must survive.
    migraphx::program prog;
    auto* mod           = prog.get_main_module();
    const auto in_shape = migraphx::shape{migraphx::shape::float_type, {4, 4}};
    const auto contig   = migraphx::make_op("contiguous");

    auto input = mod->add_parameter("x", in_shape);
    auto trans = mod->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), input);
    auto inner = mod->add_instruction(contig, trans);
    auto outer = mod->add_instruction(contig, inner);
    mod->add_return({outer});

    const auto expected_shape = prog.get_output_shapes().back();
    const auto count_before   = std::distance(prog.begin(), prog.end());
    run_pass(prog);
    EXPECT(prog.get_output_shapes().back() == expected_shape);
    EXPECT(std::distance(prog.begin(), prog.end()) == count_before - 1);
    EXPECT(mod->has_instruction(trans));
}

189
190
191
TEST_CASE(transpose_partial1)
{
    // Two consecutive, non-cancelling transposes. They should compose into
    // one transpose, which removes a single instruction.
    migraphx::program prog;
    auto* mod           = prog.get_main_module();
    const auto in_shape = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};

    auto input = mod->add_parameter("x", in_shape);
    auto step1 = mod->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), input);
    auto step2 = mod->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 2, 0}}}), step1);
    mod->add_return({step2});

    const auto expected_shape = prog.get_output_shapes().back();
    const auto count_before   = std::distance(prog.begin(), prog.end());
    run_pass(prog);
    EXPECT(prog.get_output_shapes().back() == expected_shape);
    EXPECT(std::distance(prog.begin(), prog.end()) == count_before - 1);
}

TEST_CASE(transpose_partial2)
{
    // A chain of three transposes should collapse into one, removing two
    // instructions while preserving the output shape.
    migraphx::program prog;
    auto* mod           = prog.get_main_module();
    const auto in_shape = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    const auto swap01   = migraphx::make_op("transpose", {{"dims", {1, 0, 2}}});

    auto input = mod->add_parameter("x", in_shape);
    auto step1 = mod->add_instruction(swap01, input);
    auto step2 = mod->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 2, 0}}}), step1);
    auto step3 = mod->add_instruction(swap01, step2);
    mod->add_return({step3});

    const auto expected_shape = prog.get_output_shapes().back();
    const auto count_before   = std::distance(prog.begin(), prog.end());
    run_pass(prog);
    EXPECT(prog.get_output_shapes().back() == expected_shape);
    EXPECT(std::distance(prog.begin(), prog.end()) == count_before - 2);
}

TEST_CASE(transpose_partial3)
{
    // A chain of four transposes should collapse into one, removing three
    // instructions while preserving the output shape.
    migraphx::program prog;
    auto* mod           = prog.get_main_module();
    const auto in_shape = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    const auto swap01   = migraphx::make_op("transpose", {{"dims", {1, 0, 2}}});

    auto input = mod->add_parameter("x", in_shape);
    auto step1 = mod->add_instruction(swap01, input);
    auto step2 = mod->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 2, 0}}}), step1);
    auto step3 = mod->add_instruction(swap01, step2);
    auto step4 = mod->add_instruction(swap01, step3);
    mod->add_return({step4});

    const auto expected_shape = prog.get_output_shapes().back();
    const auto count_before   = std::distance(prog.begin(), prog.end());
    run_pass(prog);
    EXPECT(prog.get_output_shapes().back() == expected_shape);
    EXPECT(std::distance(prog.begin(), prog.end()) == count_before - 3);
}

Paul's avatar
Paul committed
243
244
245
TEST_CASE(nop_transpose1)
{
    // A transpose with the identity permutation {0, 1, 2} is a no-op and
    // should be removed.
    migraphx::program prog;
    auto* mod           = prog.get_main_module();
    const auto in_shape = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};

    auto input = mod->add_parameter("x", in_shape);
    auto ident =
        mod->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), input);
    mod->add_return({ident});

    const auto expected_shape = prog.get_output_shapes().back();
    const auto count_before   = std::distance(prog.begin(), prog.end());
    run_pass(prog);
    EXPECT(prog.get_output_shapes().back() == expected_shape);
    EXPECT(std::distance(prog.begin(), prog.end()) == count_before - 1);
}

TEST_CASE(nop_transpose2)
{
    // Four chained identity transposes feed a pass_op (no explicit return).
    // All four are no-ops and should be removed.
    migraphx::program prog;
    auto* mod           = prog.get_main_module();
    const auto in_shape = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    const auto identity = migraphx::make_op("transpose", {{"dims", {0, 1, 2}}});

    auto input = mod->add_parameter("x", in_shape);
    auto id1   = mod->add_instruction(identity, input);
    auto id2   = mod->add_instruction(identity, id1);
    auto id3   = mod->add_instruction(identity, id2);
    auto id4   = mod->add_instruction(identity, id3);
    mod->add_instruction(pass_op{}, id4);

    const auto expected_shape = prog.get_output_shapes().back();
    const auto count_before   = std::distance(prog.begin(), prog.end());
    run_pass(prog);
    EXPECT(prog.get_output_shapes().back() == expected_shape);
    EXPECT(std::distance(prog.begin(), prog.end()) == count_before - 4);
}

TEST_CASE(nop_transpose3)
{
    // concat -> identity transpose -> real transpose. Only the identity
    // transpose is a no-op, so exactly one instruction should be removed.
    migraphx::program prog;
    auto* mod           = prog.get_main_module();
    const auto in_shape = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};

    auto lhs    = mod->add_parameter("x", in_shape);
    auto rhs    = mod->add_parameter("y", in_shape);
    auto joined = mod->add_instruction(migraphx::make_op("concat", {{"axis", 3}}), lhs, rhs);
    auto ident =
        mod->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2, 3}}}), joined);
    auto swapped =
        mod->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), ident);
    mod->add_return({swapped});

    const auto expected_shape = prog.get_output_shapes().back();
    const auto count_before   = std::distance(prog.begin(), prog.end());
    run_pass(prog);
    EXPECT(prog.get_output_shapes().back() == expected_shape);
    EXPECT(std::distance(prog.begin(), prog.end()) == count_before - 1);
}
Paul Fultz II's avatar
Paul Fultz II committed
296
297
298
299

TEST_CASE(nop_convert)
{
    // Converting float to float changes nothing, so the convert should be
    // removed.
    migraphx::program prog;
    auto* mod           = prog.get_main_module();
    const auto in_shape = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};

    auto input            = mod->add_parameter("x", in_shape);
    const auto same_type  = migraphx::to_value(migraphx::shape::float_type);
    const auto convert_op = migraphx::make_op("convert", {{"target_type", same_type}});
    auto converted        = mod->add_instruction(convert_op, input);
    mod->add_return({converted});

    const auto expected_shape = prog.get_output_shapes().back();
    const auto count_before   = std::distance(prog.begin(), prog.end());
    run_pass(prog);
    EXPECT(prog.get_output_shapes().back() == expected_shape);
    EXPECT(std::distance(prog.begin(), prog.end()) == count_before - 1);
}
Paul's avatar
Paul committed
315
316
317
318

TEST_CASE(concat_transpose1)
{
    // Every concat input is transposed with {0, 1, 3, 2}, and the concat
    // result is transposed with {0, 1, 3, 2} again. The two permutations
    // cancel, so the pass should produce a single concat on the original
    // inputs, with no transposes left.
    migraphx::program p;
    auto* mm    = p.get_main_module();
    auto s      = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x      = mm->add_parameter("x", s);
    auto y      = mm->add_parameter("y", s);
    auto xt     = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), x);
    auto yt     = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), y);
    auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 2}}), xt, yt);
    auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), concat);
    mm->add_return({t});
    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(p.begin(), p.end());
    run_pass(p);
    // Only the dimensions are compared; strides may presumably differ after
    // the rewrite.
    EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
    EXPECT(std::distance(p.begin(), p.end()) == n - 3);
    // Bind by const reference: a by-value `auto ins` copied each instruction
    // (operator, shape, argument and output lists) just to read its name.
    auto is_concat  = [](const auto& ins) { return ins.name() == "concat"; };
    auto new_concat = std::find_if(p.begin(), p.end(), is_concat);
    EXPECT(bool{new_concat != p.end()});
    // Axis 2 of the {0, 1, 3, 2}-permuted inputs is axis 3 of the originals.
    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3);
}

Paul's avatar
Paul committed
340
341
342
TEST_CASE(concat_transpose2)
{
    // Both concat inputs are transposed with {0, 2, 3, 1}, concatenated on
    // the last axis, and transposed again. The input transposes should be
    // folded into the concat, which removes two instructions.
    migraphx::program p;
    auto* mm    = p.get_main_module();
    auto s      = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x      = mm->add_parameter("x", s);
    auto y      = mm->add_parameter("y", s);
    auto xt     = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), x);
    auto yt     = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), y);
    auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", -1}}), xt, yt);
    auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), concat);
    mm->add_return({t});
    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(p.begin(), p.end());
    run_pass(p);
    // Only the dimensions are compared; strides may presumably differ after
    // the rewrite.
    EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
    EXPECT(std::distance(p.begin(), p.end()) == n - 2);
    // Bind by const reference instead of copying every instruction.
    auto is_concat  = [](const auto& ins) { return ins.name() == "concat"; };
    auto new_concat = std::find_if(p.begin(), p.end(), is_concat);
    EXPECT(bool{new_concat != p.end()});
    // The last axis of the {0, 2, 3, 1}-permuted inputs is axis 1 of the
    // originals.
    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}

364
365
366
TEST_CASE(concat_transpose3)
{
    // Like concat_transpose2, but the inputs differ in size along axis 1 (2
    // vs 5). That is exactly the axis the concat maps back to, so the
    // rewrite must still apply.
    migraphx::program p;
    auto* mm = p.get_main_module();
    // `s` was previously declared but never used, with x's shape re-spelled
    // inline; x now uses it.
    auto s      = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x      = mm->add_parameter("x", s);
    auto y      = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {1, 5, 3, 4}});
    auto xt     = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), x);
    auto yt     = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), y);
    auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 3}}), xt, yt);
    auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), concat);
    mm->add_return({t});
    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(p.begin(), p.end());
    run_pass(p);
    // Only the dimensions are compared; strides may presumably differ after
    // the rewrite.
    EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
    EXPECT(std::distance(p.begin(), p.end()) == n - 2);
    // Bind by const reference instead of copying every instruction.
    auto is_concat  = [](const auto& ins) { return ins.name() == "concat"; };
    auto new_concat = std::find_if(p.begin(), p.end(), is_concat);
    EXPECT(bool{new_concat != p.end()});
    // Axis 3 of the {0, 2, 3, 1}-permuted inputs is axis 1 of the originals.
    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}

Shucai Xiao's avatar
Shucai Xiao committed
388
389
390
TEST_CASE(concat_transpose4)
{
    // The two concat inputs are transposed with different permutations
    // ({0, 2, 3, 1} vs {0, 1, 3, 2}). The transpose/concat rewrite must not
    // fire, so the program should come out of the pass unchanged.
    migraphx::program p;
    auto* mod = p.get_main_module();
    const migraphx::shape x_shape{migraphx::shape::float_type, {1, 1, 12, 64}};
    const migraphx::shape y_shape{migraphx::shape::float_type, {1, 12, 1, 64}};

    auto lhs = mod->add_parameter("x", x_shape);
    auto rhs = mod->add_parameter("y", y_shape);
    auto lhs_t =
        mod->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), lhs);
    auto rhs_t =
        mod->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), rhs);
    auto joined = mod->add_instruction(migraphx::make_op("concat", {{"axis", 3}}), lhs_t, rhs_t);
    auto result =
        mod->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), joined);
    mod->add_return({result});

    // Snapshot the program before the pass so it can be compared afterwards.
    migraphx::program snapshot = p;
    run_pass(p);
    EXPECT(snapshot == p);
}

Paul Fultz II's avatar
Paul Fultz II committed
408
409
410
TEST_CASE(nested_concat)
{
    // concat(concat(x, y), concat(y, x)), all on axis 1, is equivalent to a
    // single concat(x, y, y, x). The pass should flatten it to one concat,
    // removing the two inner ones.
    migraphx::program p;
    auto* mm     = p.get_main_module();
    auto s       = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x       = mm->add_parameter("x", s);
    auto y       = mm->add_parameter("y", s);
    auto concat1 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y);
    auto concat2 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), y, x);
    auto concat3 =
        mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), concat1, concat2);
    mm->add_return({concat3});
    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(p.begin(), p.end());
    run_pass(p);
    EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
    EXPECT(std::distance(p.begin(), p.end()) == n - 2);
    // Bind by const reference: a by-value `auto ins` copied each instruction
    // just to read its name.
    auto is_concat = [](const auto& ins) { return ins.name() == "concat"; };
    EXPECT(std::count_if(p.begin(), p.end(), is_concat) == 1);
}

TEST_CASE(nested_concat_partial)
{
    // Like nested_concat, but the outer concat also takes a literal directly.
    // The two inner concats should still be flattened into the outer one.
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto s   = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x   = mm->add_parameter("x", s);
    auto y   = mm->add_parameter("y", s);
    auto l   = mm->add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1, 4, 3, 4}}));
    auto concat1 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y);
    auto concat2 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), y, x);
    auto concat3 =
        mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), concat1, concat2, l);
    mm->add_return({concat3});
    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(p.begin(), p.end());
    run_pass(p);
    EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
    EXPECT(std::distance(p.begin(), p.end()) == n - 2);
    // Bind by const reference instead of copying every instruction.
    auto is_concat = [](const auto& ins) { return ins.name() == "concat"; };
    EXPECT(std::count_if(p.begin(), p.end(), is_concat) == 1);
}

452
453
454
TEST_CASE(multibroadcast_simplify)
{
    // A multibroadcast whose output lens equal the input lens is a no-op.
    // The pass should remove it; the mul (the last instruction, no explicit
    // return) keeps the graph alive.
    migraphx::program prog;
    auto* mod = prog.get_main_module();

    std::vector<size_t> s_lens{1, 2, 3, 4};
    const auto in_shape = migraphx::shape{migraphx::shape::float_type, s_lens};
    auto input          = mod->add_parameter("x", in_shape);
    auto bcast =
        mod->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", s_lens}}), input);
    mod->add_instruction(migraphx::make_op("mul"), bcast, bcast);

    const auto count_before = std::distance(prog.begin(), prog.end());
    run_pass(prog);
    EXPECT(std::distance(prog.begin(), prog.end()) == count_before - 1);
}

467
468
469
TEST_CASE(double_slice1)
{
    // Slicing [32, 256) and then [32, 64) of that result reads input
    // elements [64, 96). The pass should fold the pair into one slice.
    migraphx::program actual;
    auto* actual_mod = actual.get_main_module();
    auto actual_in   = actual_mod->add_parameter("x", {migraphx::shape::int32_type, {256}});
    auto outer_slice = actual_mod->add_instruction(
        migraphx::make_op("slice", {{"axes", {0}}, {"starts", {32}}, {"ends", {256}}}), actual_in);
    auto inner_slice = actual_mod->add_instruction(
        migraphx::make_op("slice", {{"axes", {0}}, {"starts", {32}}, {"ends", {64}}}),
        outer_slice);
    actual_mod->add_return({inner_slice});
    run_pass(actual);

    migraphx::program expected;
    auto* expected_mod = expected.get_main_module();
    auto expected_in   = expected_mod->add_parameter("x", {migraphx::shape::int32_type, {256}});
    auto fused_slice   = expected_mod->add_instruction(
        migraphx::make_op("slice", {{"axes", {0}}, {"starts", {64}}, {"ends", {96}}}),
        expected_in);
    expected_mod->add_return({fused_slice});

    EXPECT(actual == expected);
}

TEST_CASE(double_slice2)
{
    // Slicing [32, 128) and then [0, 32) of that result reads input elements
    // [32, 64). The pass should fold the pair into one slice.
    migraphx::program actual;
    auto* actual_mod = actual.get_main_module();
    auto actual_in   = actual_mod->add_parameter("x", {migraphx::shape::int32_type, {256}});
    auto outer_slice = actual_mod->add_instruction(
        migraphx::make_op("slice", {{"axes", {0}}, {"starts", {32}}, {"ends", {128}}}), actual_in);
    auto inner_slice = actual_mod->add_instruction(
        migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {32}}}), outer_slice);
    actual_mod->add_return({inner_slice});
    run_pass(actual);

    migraphx::program expected;
    auto* expected_mod = expected.get_main_module();
    auto expected_in   = expected_mod->add_parameter("x", {migraphx::shape::int32_type, {256}});
    auto fused_slice   = expected_mod->add_instruction(
        migraphx::make_op("slice", {{"axes", {0}}, {"starts", {32}}, {"ends", {64}}}),
        expected_in);
    expected_mod->add_return({fused_slice});

    EXPECT(actual == expected);
}

TEST_CASE(double_slice_multi_axes)
{
    // Slices on different axes (axis 0 then axis 1) should merge into a
    // single slice covering both axes.
    migraphx::program actual;
    auto* actual_mod = actual.get_main_module();
    auto actual_in   = actual_mod->add_parameter("x", {migraphx::shape::int32_type, {256, 128}});
    auto rows_slice  = actual_mod->add_instruction(
        migraphx::make_op("slice", {{"axes", {0}}, {"starts", {32}}, {"ends", {128}}}), actual_in);
    auto cols_slice = actual_mod->add_instruction(
        migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {32}}}), rows_slice);
    actual_mod->add_return({cols_slice});
    run_pass(actual);

    migraphx::program expected;
    auto* expected_mod = expected.get_main_module();
    auto expected_in = expected_mod->add_parameter("x", {migraphx::shape::int32_type, {256, 128}});
    const auto merged_op = migraphx::make_op(
        "slice", {{"axes", {0, 1}}, {"starts", {32, 0}}, {"ends", {128, 32}}});
    auto merged_slice = expected_mod->add_instruction(merged_op, expected_in);
    expected_mod->add_return({merged_slice});

    EXPECT(actual == expected);
}

Paul's avatar
Paul committed
545
int main(int argc, const char* argv[])
{
    // Pass the command-line arguments straight to the test driver.
    test::run(argc, argv);
}