simplify_reshapes_test.cpp 18.3 KB
Newer Older
Paul's avatar
Paul committed
1
2
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
3
#include <migraphx/pass_manager.hpp>
Paul's avatar
Paul committed
4
#include <migraphx/operators.hpp>
Paul's avatar
Paul committed
5
#include <migraphx/instruction.hpp>
Paul Fultz II's avatar
Paul Fultz II committed
6
#include <migraphx/generate.hpp>
Paul's avatar
Paul committed
7
8
9
#include <basic_ops.hpp>
#include <test.hpp>

10
// Runs simplify_reshapes followed by dead_code_elimination on `p`, in that order.
void run_pass(migraphx::program& p)
{
    migraphx::simplify_reshapes simplify{};
    migraphx::dead_code_elimination dce{};
    migraphx::run_passes(p, {simplify, dce});
}
Paul's avatar
Paul committed
14

Paul's avatar
Paul committed
15
// transpose -> contiguous -> contiguous: the pass should drop one instruction
// (five before the pass counting the return, four after) while keeping a
// standard-layout output.
TEST_CASE(double_contig)
{
    migraphx::program prog;
    auto lit        = prog.add_literal(get_2x2());
    auto transposed = prog.add_instruction(migraphx::op::transpose{{1, 0}}, lit);
    auto contig_a   = prog.add_instruction(migraphx::op::contiguous{}, transposed);
    auto contig_b   = prog.add_instruction(migraphx::op::contiguous{}, contig_a);
    prog.add_return({contig_b});

    const auto before = prog.get_output_shapes().back();
    EXPECT(before.standard());
    EXPECT(not before.transposed());

    run_pass(prog);

    const auto after = prog.get_output_shapes().back();
    EXPECT(after.standard());
    EXPECT(not after.transposed());
    EXPECT(std::distance(prog.begin(), prog.end()) == 4);

    // The data is still the transpose of the literal, so it is expected to
    // differ from get_2x2() (assumes that literal is not symmetric).
    auto value = prog.eval({}).back();
    EXPECT(value != get_2x2());
}

Paul's avatar
Paul committed
33
// Two {1, 0} transposes in a row cancel out: after the pass only the literal
// and the return remain, and evaluation reproduces the literal.
TEST_CASE(double_transpose)
{
    migraphx::program prog;
    auto lit   = prog.add_literal(get_2x2());
    auto first = prog.add_instruction(migraphx::op::transpose{{1, 0}}, lit);
    auto again = prog.add_instruction(migraphx::op::transpose{{1, 0}}, first);
    prog.add_return({again});

    const auto shape_before = prog.get_output_shapes().back();
    EXPECT(shape_before.standard());
    EXPECT(not shape_before.transposed());

    run_pass(prog);

    const auto shape_after = prog.get_output_shapes().back();
    EXPECT(shape_after.standard());
    EXPECT(not shape_after.transposed());
    EXPECT(std::distance(prog.begin(), prog.end()) == 2);

    auto value = prog.eval({}).back();
    EXPECT(value == get_2x2());
}

Paul's avatar
Paul committed
50
// transpose/contiguous twice over: everything between the literal and the
// return should be eliminated, and the evaluated result equals the literal.
TEST_CASE(double_transpose_contig)
{
    migraphx::program prog;
    auto lit      = prog.add_literal(get_2x2());
    auto tr_a     = prog.add_instruction(migraphx::op::transpose{{1, 0}}, lit);
    auto packed_a = prog.add_instruction(migraphx::op::contiguous{}, tr_a);
    auto tr_b     = prog.add_instruction(migraphx::op::transpose{{1, 0}}, packed_a);
    auto packed_b = prog.add_instruction(migraphx::op::contiguous{}, tr_b);
    prog.add_return({packed_b});

    const auto shape_before = prog.get_output_shapes().back();
    EXPECT(shape_before.standard());
    EXPECT(not shape_before.transposed());

    run_pass(prog);

    const auto shape_after = prog.get_output_shapes().back();
    EXPECT(shape_after.standard());
    EXPECT(not shape_after.transposed());
    EXPECT(std::distance(prog.begin(), prog.end()) == 2);

    auto value = prog.eval({}).back();
    EXPECT(value == get_2x2());
}

Paul's avatar
Paul committed
69
// A lone transpose is not redundant: the pass must keep all three
// instructions and the output stays transposed / non-standard.
TEST_CASE(single_transpose)
{
    migraphx::program prog;
    auto lit = prog.add_literal(get_2x2());
    auto tr  = prog.add_instruction(migraphx::op::transpose{{1, 0}}, lit);
    prog.add_return({tr});

    const auto shape_before = prog.get_output_shapes().back();
    EXPECT(not shape_before.standard());
    EXPECT(shape_before.transposed());

    run_pass(prog);

    const auto shape_after = prog.get_output_shapes().back();
    EXPECT(not shape_after.standard());
    EXPECT(shape_after.transposed());
    EXPECT(std::distance(prog.begin(), prog.end()) == 3);

    auto value = prog.eval({}).back();
    EXPECT(value != get_2x2());
}

Paul's avatar
Paul committed
85
// Same as double_transpose but without an add_return; only the output layout
// and the evaluated value are checked.
TEST_CASE(double_transpose_sin_pass)
{
    migraphx::program prog;
    auto lit   = prog.add_literal(get_2x2());
    auto first = prog.add_instruction(migraphx::op::transpose{{1, 0}}, lit);
    prog.add_instruction(migraphx::op::transpose{{1, 0}}, first);

    const auto shape_before = prog.get_output_shapes().back();
    EXPECT(shape_before.standard());
    EXPECT(not shape_before.transposed());

    run_pass(prog);

    const auto shape_after = prog.get_output_shapes().back();
    EXPECT(shape_after.standard());
    EXPECT(not shape_after.transposed());
    // TODO: Fix this
    // EXPECT(std::distance(prog.begin(), prog.end()) == 1);

    auto value = prog.eval({}).back();
    EXPECT(value == get_2x2());
}

Paul's avatar
Paul committed
102
// Single transpose without an add_return: nothing is removed and the output
// remains transposed.
TEST_CASE(single_transpose_sin_pass)
{
    migraphx::program prog;
    auto lit = prog.add_literal(get_2x2());
    prog.add_instruction(migraphx::op::transpose{{1, 0}}, lit);

    const auto shape_before = prog.get_output_shapes().back();
    EXPECT(not shape_before.standard());
    EXPECT(shape_before.transposed());

    run_pass(prog);

    const auto shape_after = prog.get_output_shapes().back();
    EXPECT(not shape_after.standard());
    EXPECT(shape_after.transposed());
    EXPECT(std::distance(prog.begin(), prog.end()) == 2);

    auto value = prog.eval({}).back();
    EXPECT(value != get_2x2());
}

Paul's avatar
Paul committed
117
118
119
// reshape -> transpose -> contiguous -> reshape back to the input shape:
// the pass must not remove any instruction from this chain.
TEST_CASE(reshape_transpose)
{
    migraphx::program prog;
    auto in_shape = migraphx::shape{migraphx::shape::float_type, {1, 112, 56, 56}};
    auto input    = prog.add_parameter("x", in_shape);
    auto split    = prog.add_instruction(migraphx::op::reshape{{1, 4, 28, 56, 56}}, input);
    auto swapped  = prog.add_instruction(migraphx::op::transpose{{0, 2, 1, 3, 4}}, split);
    auto packed   = prog.add_instruction(migraphx::op::contiguous{}, swapped);
    auto merged   = prog.add_instruction(migraphx::op::reshape{{1, 112, 56, 56}}, packed);
    prog.add_return({merged});
    EXPECT(prog.get_output_shapes().back() == in_shape);
    const auto count_before = std::distance(prog.begin(), prog.end());

    run_pass(prog);

    EXPECT(prog.get_output_shapes().back() == in_shape);
    EXPECT(std::distance(prog.begin(), prog.end()) == count_before);
}

Paul's avatar
Paul committed
134
135
136
137
138
139
140
// A single transpose followed by one contiguous must be left as is: same
// output shape, same instruction count.
TEST_CASE(transpose_contiguous)
{
    migraphx::program prog;
    auto in_shape = migraphx::shape{migraphx::shape::float_type, {4, 4}};
    auto input    = prog.add_parameter("x", in_shape);
    auto tr       = prog.add_instruction(migraphx::op::transpose{{1, 0}}, input);
    auto packed   = prog.add_instruction(migraphx::op::contiguous{}, tr);
    prog.add_return({packed});
    const auto expected_shape = prog.get_output_shapes().back();
    const auto count_before   = std::distance(prog.begin(), prog.end());

    run_pass(prog);

    EXPECT(prog.get_output_shapes().back() == expected_shape);
    EXPECT(std::distance(prog.begin(), prog.end()) == count_before);
}

// transpose -> contiguous -> contiguous: exactly one instruction is removed,
// and the transpose itself must still be present afterwards.
TEST_CASE(transpose_double_contiguous)
{
    migraphx::program prog;
    auto in_shape = migraphx::shape{migraphx::shape::float_type, {4, 4}};
    auto input    = prog.add_parameter("x", in_shape);
    auto tr       = prog.add_instruction(migraphx::op::transpose{{1, 0}}, input);
    auto packed   = prog.add_instruction(migraphx::op::contiguous{}, tr);
    auto repacked = prog.add_instruction(migraphx::op::contiguous{}, packed);
    prog.add_return({repacked});
    const auto expected_shape = prog.get_output_shapes().back();
    const auto count_before   = std::distance(prog.begin(), prog.end());

    run_pass(prog);

    EXPECT(prog.get_output_shapes().back() == expected_shape);
    EXPECT(std::distance(prog.begin(), prog.end()) == count_before - 1);
    EXPECT(prog.has_instruction(tr));
}

166
167
168
169
170
// Two non-cancelling transposes in a row: the pass is expected to reduce the
// program by one instruction while preserving the output shape.
TEST_CASE(transpose_partial1)
{
    migraphx::program prog;
    auto in_shape = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto input    = prog.add_parameter("x", in_shape);
    auto perm_a   = prog.add_instruction(migraphx::op::transpose{{1, 0, 2}}, input);
    auto perm_b   = prog.add_instruction(migraphx::op::transpose{{1, 2, 0}}, perm_a);
    prog.add_return({perm_b});
    const auto expected_shape = prog.get_output_shapes().back();
    const auto count_before   = std::distance(prog.begin(), prog.end());

    run_pass(prog);

    EXPECT(prog.get_output_shapes().back() == expected_shape);
    EXPECT(std::distance(prog.begin(), prog.end()) == count_before - 1);
}

// Three chained transposes: two instructions are expected to disappear and
// the output shape must be unchanged.
TEST_CASE(transpose_partial2)
{
    migraphx::program prog;
    auto in_shape = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto input    = prog.add_parameter("x", in_shape);
    auto perm_a   = prog.add_instruction(migraphx::op::transpose{{1, 0, 2}}, input);
    auto perm_b   = prog.add_instruction(migraphx::op::transpose{{1, 2, 0}}, perm_a);
    auto perm_c   = prog.add_instruction(migraphx::op::transpose{{1, 0, 2}}, perm_b);
    prog.add_return({perm_c});
    const auto expected_shape = prog.get_output_shapes().back();
    const auto count_before   = std::distance(prog.begin(), prog.end());

    run_pass(prog);

    EXPECT(prog.get_output_shapes().back() == expected_shape);
    EXPECT(std::distance(prog.begin(), prog.end()) == count_before - 2);
}

// Four chained transposes: three instructions are expected to disappear and
// the output shape must be unchanged.
TEST_CASE(transpose_partial3)
{
    migraphx::program prog;
    auto in_shape = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto input    = prog.add_parameter("x", in_shape);
    auto perm_a   = prog.add_instruction(migraphx::op::transpose{{1, 0, 2}}, input);
    auto perm_b   = prog.add_instruction(migraphx::op::transpose{{1, 2, 0}}, perm_a);
    auto perm_c   = prog.add_instruction(migraphx::op::transpose{{1, 0, 2}}, perm_b);
    auto perm_d   = prog.add_instruction(migraphx::op::transpose{{1, 0, 2}}, perm_c);
    prog.add_return({perm_d});
    const auto expected_shape = prog.get_output_shapes().back();
    const auto count_before   = std::distance(prog.begin(), prog.end());

    run_pass(prog);

    EXPECT(prog.get_output_shapes().back() == expected_shape);
    EXPECT(std::distance(prog.begin(), prog.end()) == count_before - 3);
}

Paul's avatar
Paul committed
214
215
216
// An identity-permutation transpose ({0, 1, 2}) should be removed.
TEST_CASE(nop_transpose1)
{
    migraphx::program prog;
    auto in_shape = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto input    = prog.add_parameter("x", in_shape);
    auto identity = prog.add_instruction(migraphx::op::transpose{{0, 1, 2}}, input);
    prog.add_return({identity});
    const auto expected_shape = prog.get_output_shapes().back();
    const auto count_before   = std::distance(prog.begin(), prog.end());

    run_pass(prog);

    EXPECT(prog.get_output_shapes().back() == expected_shape);
    EXPECT(std::distance(prog.begin(), prog.end()) == count_before - 1);
}

// Four identity transposes feeding a pass_op: all four should be removed.
TEST_CASE(nop_transpose2)
{
    migraphx::program prog;
    auto in_shape = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto input    = prog.add_parameter("x", in_shape);
    auto id_a     = prog.add_instruction(migraphx::op::transpose{{0, 1, 2}}, input);
    auto id_b     = prog.add_instruction(migraphx::op::transpose{{0, 1, 2}}, id_a);
    auto id_c     = prog.add_instruction(migraphx::op::transpose{{0, 1, 2}}, id_b);
    auto id_d     = prog.add_instruction(migraphx::op::transpose{{0, 1, 2}}, id_c);
    prog.add_instruction(pass_op{}, id_d);
    const auto expected_shape = prog.get_output_shapes().back();
    const auto count_before   = std::distance(prog.begin(), prog.end());

    run_pass(prog);

    EXPECT(prog.get_output_shapes().back() == expected_shape);
    EXPECT(std::distance(prog.begin(), prog.end()) == count_before - 4);
}

// concat followed by an identity transpose and a real {0, 1, 3, 2} transpose:
// exactly one instruction is expected to disappear (presumably the identity
// transpose).
TEST_CASE(nop_transpose3)
{
    migraphx::program prog;
    auto in_shape = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto lhs      = prog.add_parameter("x", in_shape);
    auto rhs      = prog.add_parameter("y", in_shape);
    auto joined   = prog.add_instruction(migraphx::op::concat{3}, lhs, rhs);
    auto identity = prog.add_instruction(migraphx::op::transpose{{0, 1, 2, 3}}, joined);
    auto swapped  = prog.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, identity);
    prog.add_return({swapped});
    const auto expected_shape = prog.get_output_shapes().back();
    const auto count_before   = std::distance(prog.begin(), prog.end());

    run_pass(prog);

    EXPECT(prog.get_output_shapes().back() == expected_shape);
    EXPECT(std::distance(prog.begin(), prog.end()) == count_before - 1);
}
Paul Fultz II's avatar
Paul Fultz II committed
261
262
263
264
265
266
267
268
269
270
271
272
273
274

// Converting a float_type input to float_type is a no-op and should be removed.
TEST_CASE(nop_convert)
{
    migraphx::program prog;
    auto in_shape  = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto input     = prog.add_parameter("x", in_shape);
    auto converted = prog.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, input);
    prog.add_return({converted});
    const auto expected_shape = prog.get_output_shapes().back();
    const auto count_before   = std::distance(prog.begin(), prog.end());

    run_pass(prog);

    EXPECT(prog.get_output_shapes().back() == expected_shape);
    EXPECT(std::distance(prog.begin(), prog.end()) == count_before - 1);
}
Paul's avatar
Paul committed
275
276
277
278

// concat(transpose(x), transpose(y)) on axis 2 followed by the same
// {0, 1, 3, 2} transpose should become a single concat on axis 3, with three
// fewer instructions and the same output lengths.
TEST_CASE(concat_transpose1)
{
    migraphx::program p;
    auto s      = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x      = p.add_parameter("x", s);
    auto y      = p.add_parameter("y", s);
    auto xt     = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
    auto yt     = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, y);
    auto concat = p.add_instruction(migraphx::op::concat{2}, xt, yt);
    auto t      = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, concat);
    p.add_return({t});
    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(p.begin(), p.end());
    run_pass(p);
    EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
    EXPECT(std::distance(p.begin(), p.end()) == n - 3);
    // Bind each instruction by const reference; a by-value parameter would copy
    // every instruction the search visits.
    auto new_concat = std::find_if(
        p.begin(), p.end(), [](const auto& ins) { return ins.name() == "concat"; });
    EXPECT(bool{new_concat != p.end()});
    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3);
}

Paul's avatar
Paul committed
298
299
300
301
302
303
304
305
// concat on axis -1 of two {0, 2, 3, 1}-transposed inputs, followed by another
// {0, 2, 3, 1} transpose: after the pass there are two fewer instructions and
// the remaining concat uses axis 1; output lengths are preserved.
TEST_CASE(concat_transpose2)
{
    migraphx::program p;
    auto s      = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x      = p.add_parameter("x", s);
    auto y      = p.add_parameter("y", s);
    auto xt     = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
    auto yt     = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
    auto concat = p.add_instruction(migraphx::op::concat{-1}, xt, yt);
    auto t      = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
    p.add_return({t});
    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(p.begin(), p.end());
    run_pass(p);
    EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
    EXPECT(std::distance(p.begin(), p.end()) == n - 2);
    // Bind by const reference to avoid copying each visited instruction.
    auto new_concat = std::find_if(
        p.begin(), p.end(), [](const auto& ins) { return ins.name() == "concat"; });
    EXPECT(bool{new_concat != p.end()});
    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}

320
321
322
323
324
325
326
327
328
329
// Like concat_transpose2, but x and y differ on axis 1 (2 vs 5) and the concat
// uses an explicit axis 3: the pass should still produce a concat on axis 1
// with two fewer instructions and identical output lengths.
TEST_CASE(concat_transpose3)
{
    migraphx::program p;
    // Each parameter spells out its own shape, so no shared shape local is needed.
    auto x      = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}});
    auto y      = p.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {1, 5, 3, 4}});
    auto xt     = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
    auto yt     = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
    auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt);
    auto t      = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
    p.add_return({t});
    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(p.begin(), p.end());
    run_pass(p);
    EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
    EXPECT(std::distance(p.begin(), p.end()) == n - 2);
    // Bind by const reference to avoid copying each visited instruction.
    auto new_concat = std::find_if(
        p.begin(), p.end(), [](const auto& ins) { return ins.name() == "concat"; });
    EXPECT(bool{new_concat != p.end()});
    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}

Shucai Xiao's avatar
Shucai Xiao committed
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
// The inputs are transposed with different permutations, so the concat cannot
// be rewritten: the program after the pass must equal a copy taken before it.
TEST_CASE(concat_transpose4)
{
    migraphx::program prog;
    auto shape_x = migraphx::shape{migraphx::shape::float_type, {1, 1, 12, 64}};
    auto shape_y = migraphx::shape{migraphx::shape::float_type, {1, 12, 1, 64}};
    auto in_x    = prog.add_parameter("x", shape_x);
    auto in_y    = prog.add_parameter("y", shape_y);
    auto tr_x    = prog.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, in_x);
    auto tr_y    = prog.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, in_y);
    auto joined  = prog.add_instruction(migraphx::op::concat{3}, tr_x, tr_y);
    auto result  = prog.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, joined);
    prog.add_return({result});

    migraphx::program snapshot = prog;

    run_pass(prog);
    EXPECT(snapshot == prog);
}

Paul Fultz II's avatar
Paul Fultz II committed
361
362
363
364
365
366
367
368
369
// concat(concat(x, y), concat(y, x)) on axis 1 should be flattened into a
// single concat, removing two instructions while keeping output lengths.
TEST_CASE(nested_concat)
{
    migraphx::program p;
    auto s       = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x       = p.add_parameter("x", s);
    auto y       = p.add_parameter("y", s);
    auto concat1 = p.add_instruction(migraphx::op::concat{1}, x, y);
    auto concat2 = p.add_instruction(migraphx::op::concat{1}, y, x);
    auto concat3 = p.add_instruction(migraphx::op::concat{1}, concat1, concat2);
    p.add_return({concat3});
    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(p.begin(), p.end());
    run_pass(p);
    EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
    EXPECT(std::distance(p.begin(), p.end()) == n - 2);
    // Bind by const reference to avoid copying each visited instruction.
    EXPECT(std::count_if(p.begin(), p.end(), [](const auto& ins) {
               return ins.name() == "concat";
           }) == 1);
}

// Nested concats plus a literal operand on the outer concat: the inner concats
// should still be flattened into one concat (two instructions fewer).
TEST_CASE(nested_concat_partial)
{
    migraphx::program p;
    auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x = p.add_parameter("x", s);
    auto y = p.add_parameter("y", s);
    auto l = p.add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1, 4, 3, 4}}));
    auto concat1 = p.add_instruction(migraphx::op::concat{1}, x, y);
    auto concat2 = p.add_instruction(migraphx::op::concat{1}, y, x);
    auto concat3 = p.add_instruction(migraphx::op::concat{1}, concat1, concat2, l);
    p.add_return({concat3});
    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(p.begin(), p.end());
    run_pass(p);
    EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
    EXPECT(std::distance(p.begin(), p.end()) == n - 2);
    // Bind by const reference to avoid copying each visited instruction.
    EXPECT(std::count_if(p.begin(), p.end(), [](const auto& ins) {
               return ins.name() == "concat";
           }) == 1);
}

399
400
401
402
403
404
405
406
407
408
409
410
411
// A multibroadcast to the input's own lengths is expected to be removed,
// leaving one instruction fewer.
TEST_CASE(multibroadcast_simplify)
{
    migraphx::program prog;
    std::vector<size_t> dims{1, 2, 3, 4};
    auto in_shape  = migraphx::shape{migraphx::shape::float_type, dims};
    auto input     = prog.add_parameter("x", in_shape);
    auto broadcast = prog.add_instruction(migraphx::op::multibroadcast{dims}, input);
    prog.add_instruction(migraphx::op::mul{}, broadcast, broadcast);
    const auto count_before = std::distance(prog.begin(), prog.end());

    run_pass(prog);

    EXPECT(std::distance(prog.begin(), prog.end()) == count_before - 1);
}

412
413
414
415
416
417
418
// slice{{0}, {32}, {256}} followed by slice{{0}, {32}, {64}} should be merged
// into the single slice{{0}, {64}, {96}}.
TEST_CASE(double_slice1)
{
    migraphx::program actual;
    {
        auto input = actual.add_parameter("x", {migraphx::shape::int32_type, {256}});
        auto outer = actual.add_instruction(migraphx::op::slice{{0}, {32}, {256}}, input);
        auto inner = actual.add_instruction(migraphx::op::slice{{0}, {32}, {64}}, outer);
        actual.add_return({inner});
    }
    run_pass(actual);

    migraphx::program expected;
    {
        auto input = expected.add_parameter("x", {migraphx::shape::int32_type, {256}});
        auto fused = expected.add_instruction(migraphx::op::slice{{0}, {64}, {96}}, input);
        expected.add_return({fused});
    }
    EXPECT(actual == expected);
}

// slice{{0}, {32}, {128}} followed by slice{{0}, {0}, {32}} should be merged
// into the single slice{{0}, {32}, {64}}.
TEST_CASE(double_slice2)
{
    migraphx::program actual;
    {
        auto input = actual.add_parameter("x", {migraphx::shape::int32_type, {256}});
        auto outer = actual.add_instruction(migraphx::op::slice{{0}, {32}, {128}}, input);
        auto inner = actual.add_instruction(migraphx::op::slice{{0}, {0}, {32}}, outer);
        actual.add_return({inner});
    }
    run_pass(actual);

    migraphx::program expected;
    {
        auto input = expected.add_parameter("x", {migraphx::shape::int32_type, {256}});
        auto fused = expected.add_instruction(migraphx::op::slice{{0}, {32}, {64}}, input);
        expected.add_return({fused});
    }
    EXPECT(actual == expected);
}

// A slice on axis 0 followed by a slice on axis 1 should be merged into one
// two-axis slice{{0, 1}, {32, 0}, {128, 32}}.
TEST_CASE(double_slice_multi_axes)
{
    migraphx::program actual;
    {
        auto input  = actual.add_parameter("x", {migraphx::shape::int32_type, {256, 128}});
        auto rows   = actual.add_instruction(migraphx::op::slice{{0}, {32}, {128}}, input);
        auto region = actual.add_instruction(migraphx::op::slice{{1}, {0}, {32}}, rows);
        actual.add_return({region});
    }
    run_pass(actual);

    migraphx::program expected;
    {
        auto input = expected.add_parameter("x", {migraphx::shape::int32_type, {256, 128}});
        auto fused =
            expected.add_instruction(migraphx::op::slice{{0, 1}, {32, 0}, {128, 32}}, input);
        expected.add_return({fused});
    }
    EXPECT(actual == expected);
}

Paul's avatar
Paul committed
472
// Entry point: hands the command line to the test driver.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
    return 0;
}