// Unit tests for the migraphx::simplify_reshapes pass (simplify_reshapes_test.cpp).
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>

#include <algorithm>
#include <iterator>

#include <test.hpp>

// Applies simplify_reshapes to `m`, then dead_code_elimination.
// DCE runs second so that instructions the rewrite leaves unused are
// (presumably) dropped before the tests count the module's instructions.
void run_pass(migraphx::module& m)
{
    migraphx::run_passes(m, {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}});
}

Paul's avatar
Paul committed
19
TEST_CASE(double_contig)
Paul's avatar
Paul committed
20
{
Paul's avatar
Paul committed
21
    migraphx::program p;
22
    auto* mm = p.get_main_module();
Paul Fultz II's avatar
Paul Fultz II committed
23
24
25
26
27

    auto l  = mm->add_literal(get_2x2());
    auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
    auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), t1);
    auto c2 = mm->add_instruction(migraphx::make_op("contiguous"), c1);
28
    mm->add_return({c2});
Paul Fultz II's avatar
Paul Fultz II committed
29
30
31
32
33
    EXPECT(mm->get_output_shapes().back().standard());
    EXPECT(not mm->get_output_shapes().back().transposed());
    run_pass(*mm);
    EXPECT(mm->get_output_shapes().back().standard());
    EXPECT(not mm->get_output_shapes().back().transposed());
Shucai Xiao's avatar
Shucai Xiao committed
34
    EXPECT(std::distance(mm->begin(), mm->end()) == 4);
35
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
36
    EXPECT(result != get_2x2());
Paul's avatar
Paul committed
37
38
}

Paul's avatar
Paul committed
39
TEST_CASE(double_transpose)
Paul's avatar
Paul committed
40
{
Paul's avatar
Paul committed
41
    migraphx::program p;
42
    auto* mm = p.get_main_module();
Paul Fultz II's avatar
Paul Fultz II committed
43
44
45
46

    auto l  = mm->add_literal(get_2x2());
    auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
    auto t2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), t1);
47
    mm->add_return({t2});
Paul Fultz II's avatar
Paul Fultz II committed
48
49
50
51
52
    EXPECT(mm->get_output_shapes().back().standard());
    EXPECT(not mm->get_output_shapes().back().transposed());
    run_pass(*mm);
    EXPECT(mm->get_output_shapes().back().standard());
    EXPECT(not mm->get_output_shapes().back().transposed());
Shucai Xiao's avatar
Shucai Xiao committed
53
    EXPECT(std::distance(mm->begin(), mm->end()) == 2);
54
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
55
56
57
    EXPECT(result == get_2x2());
}

Paul's avatar
Paul committed
58
TEST_CASE(double_transpose_contig)
Paul's avatar
Paul committed
59
{
Paul's avatar
Paul committed
60
    migraphx::program p;
61
    auto* mm = p.get_main_module();
Paul Fultz II's avatar
Paul Fultz II committed
62
63
64
65
66
67

    auto l  = mm->add_literal(get_2x2());
    auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
    auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), t1);
    auto t2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), c1);
    auto c2 = mm->add_instruction(migraphx::make_op("contiguous"), t2);
68
    mm->add_return({c2});
Paul Fultz II's avatar
Paul Fultz II committed
69
70
71
72
73
    EXPECT(mm->get_output_shapes().back().standard());
    EXPECT(not mm->get_output_shapes().back().transposed());
    run_pass(*mm);
    EXPECT(mm->get_output_shapes().back().standard());
    EXPECT(not mm->get_output_shapes().back().transposed());
Shucai Xiao's avatar
Shucai Xiao committed
74
    EXPECT(std::distance(mm->begin(), mm->end()) == 2);
75
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
76
77
78
    EXPECT(result == get_2x2());
}

Paul's avatar
Paul committed
79
TEST_CASE(single_transpose)
Paul's avatar
Paul committed
80
{
Paul's avatar
Paul committed
81
    migraphx::program p;
82
    auto* mm = p.get_main_module();
Paul Fultz II's avatar
Paul Fultz II committed
83
84
85

    auto l  = mm->add_literal(get_2x2());
    auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
86
    mm->add_return({t1});
Paul Fultz II's avatar
Paul Fultz II committed
87
88
89
90
91
    EXPECT(not mm->get_output_shapes().back().standard());
    EXPECT(mm->get_output_shapes().back().transposed());
    run_pass(*mm);
    EXPECT(not mm->get_output_shapes().back().standard());
    EXPECT(mm->get_output_shapes().back().transposed());
Shucai Xiao's avatar
Shucai Xiao committed
92
    EXPECT(std::distance(mm->begin(), mm->end()) == 3);
93
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
94
95
96
    EXPECT(result != get_2x2());
}

Paul's avatar
Paul committed
97
TEST_CASE(double_transpose_sin_pass)
Paul's avatar
Paul committed
98
{
Paul's avatar
Paul committed
99
    migraphx::program p;
100
    auto* mm = p.get_main_module();
Paul Fultz II's avatar
Paul Fultz II committed
101
102
103

    auto l  = mm->add_literal(get_2x2());
    auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
104
    mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), t1);
Paul Fultz II's avatar
Paul Fultz II committed
105
106
107
108
109
    EXPECT(mm->get_output_shapes().back().standard());
    EXPECT(not mm->get_output_shapes().back().transposed());
    run_pass(*mm);
    EXPECT(mm->get_output_shapes().back().standard());
    EXPECT(not mm->get_output_shapes().back().transposed());
Paul's avatar
Paul committed
110
    // TODO: Fix this
Shucai Xiao's avatar
Shucai Xiao committed
111
    // EXPECT(std::distance(mm->begin(), mm->end()) == 1);
112
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
113
114
115
    EXPECT(result == get_2x2());
}

Paul's avatar
Paul committed
116
TEST_CASE(single_transpose_sin_pass)
Paul's avatar
Paul committed
117
{
Paul's avatar
Paul committed
118
    migraphx::program p;
119
    auto* mm = p.get_main_module();
Paul Fultz II's avatar
Paul Fultz II committed
120
121

    auto l = mm->add_literal(get_2x2());
122
    mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
Paul Fultz II's avatar
Paul Fultz II committed
123
124
125
126
127
    EXPECT(not mm->get_output_shapes().back().standard());
    EXPECT(mm->get_output_shapes().back().transposed());
    run_pass(*mm);
    EXPECT(not mm->get_output_shapes().back().standard());
    EXPECT(mm->get_output_shapes().back().transposed());
Shucai Xiao's avatar
Shucai Xiao committed
128
    EXPECT(std::distance(mm->begin(), mm->end()) == 2);
129
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
130
131
132
    EXPECT(result != get_2x2());
}

Paul's avatar
Paul committed
133
134
TEST_CASE(reshape_transpose)
{
Paul Fultz II's avatar
Paul Fultz II committed
135
136
137
138
139
140
141
142
143
144
145
146
147
148
    migraphx::module m;

    auto s  = migraphx::shape{migraphx::shape::float_type, {1, 112, 56, 56}};
    auto x  = m.add_parameter("x", s);
    auto r1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 4, 28, 56, 56}}}), x);
    auto t  = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 1, 3, 4}}}), r1);
    auto ct = m.add_instruction(migraphx::make_op("contiguous"), t);
    auto r2 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 112, 56, 56}}}), ct);
    m.add_return({r2});
    EXPECT(m.get_output_shapes().back() == s);
    auto n = std::distance(m.begin(), m.end());
    run_pass(m);
    EXPECT(m.get_output_shapes().back() == s);
    EXPECT(std::distance(m.begin(), m.end()) == n);
Paul's avatar
Paul committed
149
150
}

Paul's avatar
Paul committed
151
152
TEST_CASE(transpose_contiguous)
{
Paul Fultz II's avatar
Paul Fultz II committed
153
154
155
156
157
158
159
160
161
162
163
164
    migraphx::module m;

    auto s  = migraphx::shape{migraphx::shape::float_type, {4, 4}};
    auto x  = m.add_parameter("x", s);
    auto t  = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), x);
    auto c1 = m.add_instruction(migraphx::make_op("contiguous"), t);
    m.add_return({c1});
    auto out_shape = m.get_output_shapes().back();
    auto n         = std::distance(m.begin(), m.end());
    run_pass(m);
    EXPECT(m.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(m.begin(), m.end()) == n);
Paul's avatar
Paul committed
165
166
167
168
}

// transpose followed by two contiguous ops: one instruction is removed and
// the transpose itself survives.
TEST_CASE(transpose_double_contiguous)
{
    migraphx::module m;

    auto s  = migraphx::shape{migraphx::shape::float_type, {4, 4}};
    auto x  = m.add_parameter("x", s);
    auto t  = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), x);
    auto c1 = m.add_instruction(migraphx::make_op("contiguous"), t);
    auto c2 = m.add_instruction(migraphx::make_op("contiguous"), c1);
    m.add_return({c2});
    auto out_shape = m.get_output_shapes().back();
    auto n         = std::distance(m.begin(), m.end());
    run_pass(m);
    EXPECT(m.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(m.begin(), m.end()) == n - 1);
    EXPECT(m.has_instruction(t));
}

185
186
TEST_CASE(transpose_partial1)
{
Paul Fultz II's avatar
Paul Fultz II committed
187
188
189
190
191
192
193
194
195
196
197
198
    migraphx::module m;

    auto s  = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto x  = m.add_parameter("x", s);
    auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), x);
    auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 2, 0}}}), t1);
    m.add_return({t2});
    auto out_shape = m.get_output_shapes().back();
    auto n         = std::distance(m.begin(), m.end());
    run_pass(m);
    EXPECT(m.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(m.begin(), m.end()) == n - 1);
199
200
201
202
}

// Three transposes in a row: the count drops by two and the output shape is
// preserved.
TEST_CASE(transpose_partial2)
{
    migraphx::module m;

    auto s  = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto x  = m.add_parameter("x", s);
    auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), x);
    auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 2, 0}}}), t1);
    auto t3 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), t2);
    m.add_return({t3});
    auto out_shape = m.get_output_shapes().back();
    auto n         = std::distance(m.begin(), m.end());
    run_pass(m);
    EXPECT(m.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(m.begin(), m.end()) == n - 2);
}

// Four transposes in a row: the count drops by three and the output shape
// is preserved.
TEST_CASE(transpose_partial3)
{
    migraphx::module m;

    auto s  = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto x  = m.add_parameter("x", s);
    auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), x);
    auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 2, 0}}}), t1);
    auto t3 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), t2);
    auto t4 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), t3);
    m.add_return({t4});
    auto out_shape = m.get_output_shapes().back();
    auto n         = std::distance(m.begin(), m.end());
    run_pass(m);
    EXPECT(m.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(m.begin(), m.end()) == n - 3);
}

Paul's avatar
Paul committed
236
237
TEST_CASE(nop_transpose1)
{
Paul Fultz II's avatar
Paul Fultz II committed
238
239
240
241
242
243
244
245
246
247
248
    migraphx::module m;

    auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto x = m.add_parameter("x", s);
    auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), x);
    m.add_return({t});
    auto out_shape = m.get_output_shapes().back();
    auto n         = std::distance(m.begin(), m.end());
    run_pass(m);
    EXPECT(m.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(m.begin(), m.end()) == n - 1);
Paul's avatar
Paul committed
249
250
251
252
}

// A chain of four identity transposes feeding the pass_op test helper.
// All four transposes are removed.
TEST_CASE(nop_transpose2)
{
    migraphx::module m;

    auto s  = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto x  = m.add_parameter("x", s);
    auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), x);
    auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), t1);
    auto t3 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), t2);
    auto t4 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), t3);
    m.add_instruction(pass_op{}, t4);
    auto out_shape = m.get_output_shapes().back();
    auto n         = std::distance(m.begin(), m.end());
    run_pass(m);
    EXPECT(m.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(m.begin(), m.end()) == n - 4);
}

// An identity transpose after a concat, followed by a real transpose.
// Exactly one instruction (the identity transpose) is removed.
TEST_CASE(nop_transpose3)
{
    migraphx::module m;

    auto s      = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x      = m.add_parameter("x", s);
    auto y      = m.add_parameter("y", s);
    auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), x, y);
    auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2, 3}}}), concat);
    auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t1);
    m.add_return({t2});
    auto out_shape = m.get_output_shapes().back();
    auto n         = std::distance(m.begin(), m.end());
    run_pass(m);
    EXPECT(m.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(m.begin(), m.end()) == n - 1);
}

// A convert from float_type to float_type is a no-op and is removed.
TEST_CASE(nop_convert)
{
    migraphx::module m;

    auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto x = m.add_parameter("x", s);
    auto t = m.add_instruction(
        migraphx::make_op("convert",
                          {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
        x);
    m.add_return({t});
    auto out_shape = m.get_output_shapes().back();
    auto n         = std::distance(m.begin(), m.end());
    run_pass(m);
    EXPECT(m.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(m.begin(), m.end()) == n - 1);
}

// Both inputs are transposed by {0, 1, 3, 2}, concatenated on axis 2, and the
// result is transposed by {0, 1, 3, 2} again. After the pass:
// - the output lens are preserved,
// - the count drops by three,
// - the remaining concat runs on axis 3.
TEST_CASE(concat_transpose1)
{
    migraphx::module m;

    auto s      = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x      = m.add_parameter("x", s);
    auto y      = m.add_parameter("y", s);
    auto xt     = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), x);
    auto yt     = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), y);
    auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), xt, yt);
    auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), concat);
    m.add_return({t});
    auto out_shape = m.get_output_shapes().back();
    auto n         = std::distance(m.begin(), m.end());
    run_pass(m);
    EXPECT(m.get_output_shapes().back().lens() == out_shape.lens());
    EXPECT(std::distance(m.begin(), m.end()) == n - 3);
    // Bind by const reference: taking `auto ins` would copy every instruction visited.
    auto new_concat = std::find_if(
        m.begin(), m.end(), [](const auto& ins) { return ins.name() == "concat"; });
    EXPECT(bool{new_concat != m.end()});
    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3);
}

Paul's avatar
Paul committed
328
329
TEST_CASE(concat_transpose2)
{
Paul Fultz II's avatar
Paul Fultz II committed
330
    migraphx::module m;
331

Paul's avatar
Paul committed
332
    auto s      = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
Paul Fultz II's avatar
Paul Fultz II committed
333
334
335
336
337
338
339
340
341
342
343
344
    auto x      = m.add_parameter("x", s);
    auto y      = m.add_parameter("y", s);
    auto xt     = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), x);
    auto yt     = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), y);
    auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", -1}}), xt, yt);
    auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), concat);
    m.add_return({t});
    auto out_shape = m.get_output_shapes().back();
    auto n         = std::distance(m.begin(), m.end());
    run_pass(m);
    EXPECT(m.get_output_shapes().back().lens() == out_shape.lens());
    EXPECT(std::distance(m.begin(), m.end()) == n - 2);
Paul's avatar
Paul committed
345
    auto new_concat =
Paul Fultz II's avatar
Paul Fultz II committed
346
347
        std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
    EXPECT(bool{new_concat != m.end()});
Paul's avatar
Paul committed
348
349
350
    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}

351
352
TEST_CASE(concat_transpose3)
{
Paul Fultz II's avatar
Paul Fultz II committed
353
    migraphx::module m;
354

Paul Fultz II's avatar
Paul Fultz II committed
355
356
357
358
359
360
361
362
363
364
365
366
367
    auto s      = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x      = m.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}});
    auto y      = m.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {1, 5, 3, 4}});
    auto xt     = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), x);
    auto yt     = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), y);
    auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), xt, yt);
    auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), concat);
    m.add_return({t});
    auto out_shape = m.get_output_shapes().back();
    auto n         = std::distance(m.begin(), m.end());
    run_pass(m);
    EXPECT(m.get_output_shapes().back().lens() == out_shape.lens());
    EXPECT(std::distance(m.begin(), m.end()) == n - 2);
368
    auto new_concat =
Paul Fultz II's avatar
Paul Fultz II committed
369
370
        std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
    EXPECT(bool{new_concat != m.end()});
371
372
373
    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}

Shucai Xiao's avatar
Shucai Xiao committed
374
375
TEST_CASE(concat_transpose4)
{
Paul Fultz II's avatar
Paul Fultz II committed
376
    migraphx::module m;
Shucai Xiao's avatar
Shucai Xiao committed
377
378
    auto sx     = migraphx::shape{migraphx::shape::float_type, {1, 1, 12, 64}};
    auto sy     = migraphx::shape{migraphx::shape::float_type, {1, 12, 1, 64}};
Paul Fultz II's avatar
Paul Fultz II committed
379
380
381
382
383
384
385
386
387
388
389
390
    auto x      = m.add_parameter("x", sx);
    auto y      = m.add_parameter("y", sy);
    auto xt     = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), x);
    auto yt     = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), y);
    auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), xt, yt);
    auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), concat);
    m.add_return({t});

    migraphx::module m1 = m;
    run_pass(m);

    EXPECT(m1 == m);
Shucai Xiao's avatar
Shucai Xiao committed
391
392
}

Paul Fultz II's avatar
Paul Fultz II committed
393
394
TEST_CASE(nested_concat)
{
Paul Fultz II's avatar
Paul Fultz II committed
395
    migraphx::module m;
396

Paul Fultz II's avatar
Paul Fultz II committed
397
    auto s       = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
Paul Fultz II's avatar
Paul Fultz II committed
398
399
400
401
402
403
404
405
406
407
408
409
    auto x       = m.add_parameter("x", s);
    auto y       = m.add_parameter("y", s);
    auto concat1 = m.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y);
    auto concat2 = m.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), y, x);
    auto concat3 = m.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), concat1, concat2);
    m.add_return({concat3});
    auto out_shape = m.get_output_shapes().back();
    auto n         = std::distance(m.begin(), m.end());
    run_pass(m);
    EXPECT(m.get_output_shapes().back().lens() == out_shape.lens());
    EXPECT(std::distance(m.begin(), m.end()) == n - 2);
    EXPECT(std::count_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; }) == 1);
Paul Fultz II's avatar
Paul Fultz II committed
410
411
412
413
}

// Like nested_concat, but the outer concat also takes a literal operand.
// After the pass:
// - the output lens are preserved,
// - the count drops by two,
// - exactly one concat remains.
TEST_CASE(nested_concat_partial)
{
    migraphx::module m;

    auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x = m.add_parameter("x", s);
    auto y = m.add_parameter("y", s);
    auto l = m.add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1, 4, 3, 4}}));
    auto concat1 = m.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y);
    auto concat2 = m.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), y, x);
    auto concat3 =
        m.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), concat1, concat2, l);
    m.add_return({concat3});
    auto out_shape = m.get_output_shapes().back();
    auto n         = std::distance(m.begin(), m.end());
    run_pass(m);
    EXPECT(m.get_output_shapes().back().lens() == out_shape.lens());
    EXPECT(std::distance(m.begin(), m.end()) == n - 2);
    // Bind by const reference: taking `auto ins` would copy every instruction visited.
    EXPECT(std::count_if(m.begin(), m.end(), [](const auto& ins) {
               return ins.name() == "concat";
           }) == 1);
}

434
435
TEST_CASE(multibroadcast_simplify)
{
Paul Fultz II's avatar
Paul Fultz II committed
436
    migraphx::module m;
437

438
439
    std::vector<size_t> s_lens{1, 2, 3, 4};
    auto s = migraphx::shape{migraphx::shape::float_type, s_lens};
Paul Fultz II's avatar
Paul Fultz II committed
440
441
442
443
444
445
    auto x = m.add_parameter("x", s);
    auto y = m.add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", s_lens}}), x);
    m.add_instruction(migraphx::make_op("mul"), y, y);
    auto n = std::distance(m.begin(), m.end());
    run_pass(m);
    EXPECT(std::distance(m.begin(), m.end()) == n - 1);
446
447
}

448
449
TEST_CASE(double_slice1)
{
Paul Fultz II's avatar
Paul Fultz II committed
450
    migraphx::module m1;
451
    {
Paul Fultz II's avatar
Paul Fultz II committed
452
453
        auto x      = m1.add_parameter("x", {migraphx::shape::int32_type, {256}});
        auto slice1 = m1.add_instruction(
454
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {32}}, {"ends", {256}}}), x);
Paul Fultz II's avatar
Paul Fultz II committed
455
        auto slice2 = m1.add_instruction(
456
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {32}}, {"ends", {64}}}), slice1);
Paul Fultz II's avatar
Paul Fultz II committed
457
        m1.add_return({slice2});
458
    }
Paul Fultz II's avatar
Paul Fultz II committed
459
    run_pass(m1);
460

Paul Fultz II's avatar
Paul Fultz II committed
461
    migraphx::module m2;
462
    {
Paul Fultz II's avatar
Paul Fultz II committed
463
464
        auto x     = m2.add_parameter("x", {migraphx::shape::int32_type, {256}});
        auto slice = m2.add_instruction(
465
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {64}}, {"ends", {96}}}), x);
Paul Fultz II's avatar
Paul Fultz II committed
466
        m2.add_return({slice});
467
    }
Paul Fultz II's avatar
Paul Fultz II committed
468
    EXPECT(m1 == m2);
469
470
471
472
}

// Slicing [32, 128) and then [0, 32) of that result is the same as a single
// slice [32, 64) of the input. The pass output (m1) must equal the
// hand-written single slice (m2).
TEST_CASE(double_slice2)
{
    migraphx::module m1;
    {
        auto x      = m1.add_parameter("x", {migraphx::shape::int32_type, {256}});
        auto slice1 = m1.add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {32}}, {"ends", {128}}}), x);
        auto slice2 = m1.add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {32}}}), slice1);
        m1.add_return({slice2});
    }
    run_pass(m1);

    migraphx::module m2;
    {
        auto x     = m2.add_parameter("x", {migraphx::shape::int32_type, {256}});
        auto slice = m2.add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {32}}, {"ends", {64}}}), x);
        m2.add_return({slice});
    }
    EXPECT(m1 == m2);
}

// A slice on axis 0 followed by a slice on axis 1 must match a single slice
// over axes {0, 1} with the combined starts/ends.
TEST_CASE(double_slice_multi_axes)
{
    migraphx::module m1;
    {
        auto x      = m1.add_parameter("x", {migraphx::shape::int32_type, {256, 128}});
        auto slice1 = m1.add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {32}}, {"ends", {128}}}), x);
        auto slice2 = m1.add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {32}}}), slice1);
        m1.add_return({slice2});
    }
    run_pass(m1);

    migraphx::module m2;

    {
        auto x     = m2.add_parameter("x", {migraphx::shape::int32_type, {256, 128}});
        auto slice = m2.add_instruction(
            migraphx::make_op("slice",
                              {{"axes", {0, 1}}, {"starts", {32, 0}}, {"ends", {128, 32}}}),
            x);
        m2.add_return({slice});
    }
    EXPECT(m1 == m2);
}

Paul's avatar
Paul committed
520
int main(int argc, const char* argv[]) { test::run(argc, argv); }