// simplify_reshapes_test.cpp: unit tests for the migraphx::simplify_reshapes pass.
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>

#include <migraphx/serialize.hpp>

#include <test.hpp>

14
void run_pass(migraphx::program& p)
Paul's avatar
Paul committed
15
{
16
17
    auto* mm = p.get_main_module();
    migraphx::run_passes(*mm, {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}});
18
}
Paul's avatar
Paul committed
19

Paul's avatar
Paul committed
20
TEST_CASE(double_contig)
Paul's avatar
Paul committed
21
{
Paul's avatar
Paul committed
22
    migraphx::program p;
23
24
25

    auto* mm = p.get_main_module();
    auto l   = mm->add_literal(get_2x2());
26
27
28
    auto t1  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
    auto c1  = mm->add_instruction(migraphx::make_op("contiguous"), t1);
    auto c2  = mm->add_instruction(migraphx::make_op("contiguous"), c1);
29
    mm->add_return({c2});
30
31
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());
32
    run_pass(p);
33
34
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());
Shucai Xiao's avatar
Shucai Xiao committed
35
    EXPECT(std::distance(mm->begin(), mm->end()) == 4);
36
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
37
    EXPECT(result != get_2x2());
Paul's avatar
Paul committed
38
39
}

Paul's avatar
Paul committed
40
TEST_CASE(double_transpose)
Paul's avatar
Paul committed
41
{
Paul's avatar
Paul committed
42
    migraphx::program p;
43
44
45

    auto* mm = p.get_main_module();
    auto l   = mm->add_literal(get_2x2());
46
47
    auto t1  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
    auto t2  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), t1);
48
    mm->add_return({t2});
49
50
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());
51
    run_pass(p);
52
53
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());
Shucai Xiao's avatar
Shucai Xiao committed
54
    EXPECT(std::distance(mm->begin(), mm->end()) == 2);
55
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
56
57
58
    EXPECT(result == get_2x2());
}

Paul's avatar
Paul committed
59
TEST_CASE(double_transpose_contig)
Paul's avatar
Paul committed
60
{
Paul's avatar
Paul committed
61
    migraphx::program p;
62
63
64

    auto* mm = p.get_main_module();
    auto l   = mm->add_literal(get_2x2());
65
66
67
68
    auto t1  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
    auto c1  = mm->add_instruction(migraphx::make_op("contiguous"), t1);
    auto t2  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), c1);
    auto c2  = mm->add_instruction(migraphx::make_op("contiguous"), t2);
69
    mm->add_return({c2});
70
71
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());
72
    run_pass(p);
73
74
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());
Shucai Xiao's avatar
Shucai Xiao committed
75
    EXPECT(std::distance(mm->begin(), mm->end()) == 2);
76
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
77
78
79
    EXPECT(result == get_2x2());
}

Paul's avatar
Paul committed
80
TEST_CASE(single_transpose)
Paul's avatar
Paul committed
81
{
Paul's avatar
Paul committed
82
    migraphx::program p;
83
84
85

    auto* mm = p.get_main_module();
    auto l   = mm->add_literal(get_2x2());
86
    auto t1  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
87
    mm->add_return({t1});
88
89
    EXPECT(not p.get_output_shapes().back().standard());
    EXPECT(p.get_output_shapes().back().transposed());
90
    run_pass(p);
91
92
    EXPECT(not p.get_output_shapes().back().standard());
    EXPECT(p.get_output_shapes().back().transposed());
Shucai Xiao's avatar
Shucai Xiao committed
93
    EXPECT(std::distance(mm->begin(), mm->end()) == 3);
94
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
95
96
97
    EXPECT(result != get_2x2());
}

Paul's avatar
Paul committed
98
TEST_CASE(double_transpose_sin_pass)
Paul's avatar
Paul committed
99
{
Paul's avatar
Paul committed
100
    migraphx::program p;
101
102
103

    auto* mm = p.get_main_module();
    auto l   = mm->add_literal(get_2x2());
104
105
    auto t1  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
    mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), t1);
106
107
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());
108
    run_pass(p);
109
110
    EXPECT(p.get_output_shapes().back().standard());
    EXPECT(not p.get_output_shapes().back().transposed());
Paul's avatar
Paul committed
111
    // TODO: Fix this
Shucai Xiao's avatar
Shucai Xiao committed
112
    // EXPECT(std::distance(mm->begin(), mm->end()) == 1);
113
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
114
115
116
    EXPECT(result == get_2x2());
}

Paul's avatar
Paul committed
117
TEST_CASE(single_transpose_sin_pass)
Paul's avatar
Paul committed
118
{
Paul's avatar
Paul committed
119
    migraphx::program p;
120
121
122

    auto* mm = p.get_main_module();
    auto l   = mm->add_literal(get_2x2());
123
    mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
124
125
    EXPECT(not p.get_output_shapes().back().standard());
    EXPECT(p.get_output_shapes().back().transposed());
126
    run_pass(p);
127
128
    EXPECT(not p.get_output_shapes().back().standard());
    EXPECT(p.get_output_shapes().back().transposed());
Shucai Xiao's avatar
Shucai Xiao committed
129
    EXPECT(std::distance(mm->begin(), mm->end()) == 2);
130
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
131
132
133
    EXPECT(result != get_2x2());
}

Paul's avatar
Paul committed
134
135
136
TEST_CASE(reshape_transpose)
{
    migraphx::program p;
137
138
139
140

    auto* mm = p.get_main_module();
    auto s   = migraphx::shape{migraphx::shape::float_type, {1, 112, 56, 56}};
    auto x   = mm->add_parameter("x", s);
141
142
143
144
    auto r1  = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 4, 28, 56, 56}}}), x);
    auto t   = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 1, 3, 4}}}), r1);
    auto ct  = mm->add_instruction(migraphx::make_op("contiguous"), t);
    auto r2  = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 112, 56, 56}}}), ct);
145
    mm->add_return({r2});
146
    EXPECT(p.get_output_shapes().back() == s);
Shucai Xiao's avatar
Shucai Xiao committed
147
    auto n = std::distance(mm->begin(), mm->end());
148
    run_pass(p);
149
    EXPECT(p.get_output_shapes().back() == s);
Shucai Xiao's avatar
Shucai Xiao committed
150
    EXPECT(std::distance(mm->begin(), mm->end()) == n);
Paul's avatar
Paul committed
151
152
}

Paul's avatar
Paul committed
153
154
155
TEST_CASE(transpose_contiguous)
{
    migraphx::program p;
156
157
158
159

    auto* mm = p.get_main_module();
    auto s   = migraphx::shape{migraphx::shape::float_type, {4, 4}};
    auto x   = mm->add_parameter("x", s);
160
161
    auto t   = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), x);
    auto c1  = mm->add_instruction(migraphx::make_op("contiguous"), t);
162
    mm->add_return({c1});
163
    auto out_shape = p.get_output_shapes().back();
Shucai Xiao's avatar
Shucai Xiao committed
164
    auto n         = std::distance(mm->begin(), mm->end());
165
    run_pass(p);
166
    EXPECT(p.get_output_shapes().back() == out_shape);
Shucai Xiao's avatar
Shucai Xiao committed
167
    EXPECT(std::distance(mm->begin(), mm->end()) == n);
Paul's avatar
Paul committed
168
169
170
171
172
}

// transpose -> contiguous -> contiguous: exactly one instruction (a redundant
// contiguous) should be removed, and the transpose itself must be kept.
TEST_CASE(transpose_double_contiguous)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto s   = migraphx::shape{migraphx::shape::float_type, {4, 4}};
    auto x   = mm->add_parameter("x", s);
    auto t   = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), x);
    auto c1  = mm->add_instruction(migraphx::make_op("contiguous"), t);
    auto c2  = mm->add_instruction(migraphx::make_op("contiguous"), c1);
    mm->add_return({c2});
    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(mm->begin(), mm->end());
    run_pass(p);
    EXPECT(p.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(mm->begin(), mm->end()) == n - 1);
    EXPECT(mm->has_instruction(t));
}

189
190
191
TEST_CASE(transpose_partial1)
{
    migraphx::program p;
192
193
194
195

    auto* mm = p.get_main_module();
    auto s   = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto x   = mm->add_parameter("x", s);
196
197
    auto t1  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), x);
    auto t2  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 2, 0}}}), t1);
198
    mm->add_return({t2});
199
    auto out_shape = p.get_output_shapes().back();
Shucai Xiao's avatar
Shucai Xiao committed
200
    auto n         = std::distance(mm->begin(), mm->end());
201
    run_pass(p);
202
    EXPECT(p.get_output_shapes().back() == out_shape);
Shucai Xiao's avatar
Shucai Xiao committed
203
    EXPECT(std::distance(mm->begin(), mm->end()) == n - 1);
204
205
206
207
208
}

// A chain of three transposes should collapse to one: two fewer instructions,
// same output shape.
TEST_CASE(transpose_partial2)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto s   = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto x   = mm->add_parameter("x", s);
    auto t1  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), x);
    auto t2  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 2, 0}}}), t1);
    auto t3  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), t2);
    mm->add_return({t3});
    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(mm->begin(), mm->end());
    run_pass(p);
    EXPECT(p.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(mm->begin(), mm->end()) == n - 2);
}

// A chain of four transposes; the pass is expected to remove three of them
// while preserving the output shape.
TEST_CASE(transpose_partial3)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto s   = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto x   = mm->add_parameter("x", s);
    auto t1  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), x);
    auto t2  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 2, 0}}}), t1);
    auto t3  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), t2);
    auto t4  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), t3);
    mm->add_return({t4});
    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(mm->begin(), mm->end());
    run_pass(p);
    EXPECT(p.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(mm->begin(), mm->end()) == n - 3);
}

Paul's avatar
Paul committed
243
244
245
TEST_CASE(nop_transpose1)
{
    migraphx::program p;
246
247
248
249

    auto* mm = p.get_main_module();
    auto s   = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto x   = mm->add_parameter("x", s);
250
    auto t   = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), x);
251
    mm->add_return({t});
252
    auto out_shape = p.get_output_shapes().back();
Shucai Xiao's avatar
Shucai Xiao committed
253
    auto n         = std::distance(mm->begin(), mm->end());
254
    run_pass(p);
255
    EXPECT(p.get_output_shapes().back() == out_shape);
Shucai Xiao's avatar
Shucai Xiao committed
256
    EXPECT(std::distance(mm->begin(), mm->end()) == n - 1);
Paul's avatar
Paul committed
257
258
259
260
261
}

// Four identity transposes chained in front of a pass_op should all be removed
// (four fewer instructions, same output shape).
TEST_CASE(nop_transpose2)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto s   = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto x   = mm->add_parameter("x", s);
    auto t1  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), x);
    auto t2  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), t1);
    auto t3  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), t2);
    auto t4  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), t3);
    mm->add_instruction(pass_op{}, t4);
    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(mm->begin(), mm->end());
    run_pass(p);
    EXPECT(p.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(mm->begin(), mm->end()) == n - 4);
}

// An identity transpose after a concat is removed, while the following real
// transpose {0,1,3,2} is kept: exactly one fewer instruction, same output shape.
TEST_CASE(nop_transpose3)
{
    migraphx::program p;

    auto* mm    = p.get_main_module();
    auto s      = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x      = mm->add_parameter("x", s);
    auto y      = mm->add_parameter("y", s);
    auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 3}}), x, y);
    auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2, 3}}}), concat);
    auto t2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t1);
    mm->add_return({t2});
    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(mm->begin(), mm->end());
    run_pass(p);
    EXPECT(p.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(mm->begin(), mm->end()) == n - 1);
}

// Converting a float input to float is a no-op and should be removed (one
// fewer instruction, same output shape).
TEST_CASE(nop_convert)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto s   = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
    auto x   = mm->add_parameter("x", s);
    auto t   = mm->add_instruction(
        migraphx::make_op("convert",
                          {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
        x);
    mm->add_return({t});
    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(mm->begin(), mm->end());
    run_pass(p);
    EXPECT(p.get_output_shapes().back() == out_shape);
    EXPECT(std::distance(mm->begin(), mm->end()) == n - 1);
}

// Both inputs are transposed with {0,1,3,2}, concatenated on axis 2, and then
// transposed again with the same self-inverse permutation. The pass should
// remove three instructions and leave a concat on axis 3 (axis 2 in the
// transposed layout). Only lens are compared, because the output strides may
// legitimately change.
TEST_CASE(concat_transpose1)
{
    migraphx::program p;

    auto* mm    = p.get_main_module();
    auto s      = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x      = mm->add_parameter("x", s);
    auto y      = mm->add_parameter("y", s);
    auto xt     = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), x);
    auto yt     = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), y);
    auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 2}}), xt, yt);
    auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), concat);
    mm->add_return({t});
    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(mm->begin(), mm->end());
    run_pass(p);
    EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
    EXPECT(std::distance(mm->begin(), mm->end()) == n - 3);
    // Take instructions by const reference; a by-value parameter copies each one.
    auto is_concat  = [](const auto& ins) { return ins.name() == "concat"; };
    auto new_concat = std::find_if(mm->begin(), mm->end(), is_concat);
    EXPECT(bool{new_concat != mm->end()});
    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3);
}

Paul's avatar
Paul committed
340
341
342
TEST_CASE(concat_transpose2)
{
    migraphx::program p;
343
344

    auto* mm    = p.get_main_module();
Paul's avatar
Paul committed
345
    auto s      = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
346
347
    auto x      = mm->add_parameter("x", s);
    auto y      = mm->add_parameter("y", s);
348
349
350
351
    auto xt     = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), x);
    auto yt     = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), y);
    auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", -1}}), xt, yt);
    auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), concat);
352
    mm->add_return({t});
353
    auto out_shape = p.get_output_shapes().back();
Shucai Xiao's avatar
Shucai Xiao committed
354
    auto n         = std::distance(mm->begin(), mm->end());
355
    run_pass(p);
356
    EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
Shucai Xiao's avatar
Shucai Xiao committed
357
    EXPECT(std::distance(mm->begin(), mm->end()) == n - 2);
Paul's avatar
Paul committed
358
    auto new_concat =
Shucai Xiao's avatar
Shucai Xiao committed
359
360
        std::find_if(mm->begin(), mm->end(), [](auto ins) { return ins.name() == "concat"; });
    EXPECT(bool{new_concat != mm->end()});
Paul's avatar
Paul committed
361
362
363
    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}

364
365
366
TEST_CASE(concat_transpose3)
{
    migraphx::program p;
367
368
369
370
371

    auto* mm = p.get_main_module();
    auto s   = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x   = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}});
    auto y   = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {1, 5, 3, 4}});
372
373
374
375
    auto xt  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), x);
    auto yt  = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), y);
    auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 3}}), xt, yt);
    auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), concat);
376
    mm->add_return({t});
377
    auto out_shape = p.get_output_shapes().back();
Shucai Xiao's avatar
Shucai Xiao committed
378
    auto n         = std::distance(mm->begin(), mm->end());
379
    run_pass(p);
380
    EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
Shucai Xiao's avatar
Shucai Xiao committed
381
    EXPECT(std::distance(mm->begin(), mm->end()) == n - 2);
382
    auto new_concat =
Shucai Xiao's avatar
Shucai Xiao committed
383
384
        std::find_if(mm->begin(), mm->end(), [](auto ins) { return ins.name() == "concat"; });
    EXPECT(bool{new_concat != mm->end()});
385
386
387
    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}

Shucai Xiao's avatar
Shucai Xiao committed
388
389
390
TEST_CASE(concat_transpose4)
{
    migraphx::program p;
391
    auto* mm    = p.get_main_module();
Shucai Xiao's avatar
Shucai Xiao committed
392
393
    auto sx     = migraphx::shape{migraphx::shape::float_type, {1, 1, 12, 64}};
    auto sy     = migraphx::shape{migraphx::shape::float_type, {1, 12, 1, 64}};
394
395
    auto x      = mm->add_parameter("x", sx);
    auto y      = mm->add_parameter("y", sy);
396
397
398
399
    auto xt     = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), x);
    auto yt     = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), y);
    auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 3}}), xt, yt);
    auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), concat);
400
    mm->add_return({t});
Shucai Xiao's avatar
Shucai Xiao committed
401
402
403

    migraphx::program p1 = p;
    run_pass(p);
404

Shucai Xiao's avatar
Shucai Xiao committed
405
406
407
    EXPECT(p1 == p);
}

Paul Fultz II's avatar
Paul Fultz II committed
408
409
410
TEST_CASE(nested_concat)
{
    migraphx::program p;
411
412

    auto* mm     = p.get_main_module();
Paul Fultz II's avatar
Paul Fultz II committed
413
    auto s       = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
414
415
    auto x       = mm->add_parameter("x", s);
    auto y       = mm->add_parameter("y", s);
416
417
418
419
    auto concat1 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y);
    auto concat2 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), y, x);
    auto concat3 =
        mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), concat1, concat2);
420
    mm->add_return({concat3});
421
    auto out_shape = p.get_output_shapes().back();
Shucai Xiao's avatar
Shucai Xiao committed
422
    auto n         = std::distance(mm->begin(), mm->end());
423
    run_pass(p);
424
    EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
Shucai Xiao's avatar
Shucai Xiao committed
425
426
427
    EXPECT(std::distance(mm->begin(), mm->end()) == n - 2);
    EXPECT(std::count_if(mm->begin(), mm->end(), [](auto ins) { return ins.name() == "concat"; }) ==
           1);
Paul Fultz II's avatar
Paul Fultz II committed
428
429
430
431
432
}

// Like nested_concat, but the outer concat also takes a literal operand. The
// nested concats should still be flattened into a single concat: two fewer
// instructions and exactly one concat left.
TEST_CASE(nested_concat_partial)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto s   = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
    auto x   = mm->add_parameter("x", s);
    auto y   = mm->add_parameter("y", s);
    auto l   = mm->add_literal(
        migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1, 4, 3, 4}}));
    auto concat1 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y);
    auto concat2 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), y, x);
    auto concat3 =
        mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), concat1, concat2, l);
    mm->add_return({concat3});
    auto out_shape = p.get_output_shapes().back();
    auto n         = std::distance(mm->begin(), mm->end());
    run_pass(p);
    EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
    EXPECT(std::distance(mm->begin(), mm->end()) == n - 2);
    // Take instructions by const reference; a by-value parameter copies each one.
    auto is_concat    = [](const auto& ins) { return ins.name() == "concat"; };
    auto concat_count = std::count_if(mm->begin(), mm->end(), is_concat);
    EXPECT(concat_count == 1);
}

454
455
456
TEST_CASE(multibroadcast_simplify)
{
    migraphx::program p;
457
458

    auto* mm = p.get_main_module();
459
460
    std::vector<size_t> s_lens{1, 2, 3, 4};
    auto s = migraphx::shape{migraphx::shape::float_type, s_lens};
461
    auto x = mm->add_parameter("x", s);
462
463
    auto y = mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", s_lens}}), x);
    mm->add_instruction(migraphx::make_op("mul"), y, y);
Shucai Xiao's avatar
Shucai Xiao committed
464
    auto n = std::distance(mm->begin(), mm->end());
465
    run_pass(p);
Shucai Xiao's avatar
Shucai Xiao committed
466
    EXPECT(std::distance(mm->begin(), mm->end()) == n - 1);
467
468
}

469
470
471
TEST_CASE(double_slice1)
{
    migraphx::program p1;
472
    auto* mm1 = p1.get_main_module();
473
    {
474
        auto x      = mm1->add_parameter("x", {migraphx::shape::int32_type, {256}});
475
476
477
478
        auto slice1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {32}}, {"ends", {256}}}), x);
        auto slice2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {32}}, {"ends", {64}}}), slice1);
479
        mm1->add_return({slice2});
480
481
482
483
    }
    run_pass(p1);

    migraphx::program p2;
484
    auto* mm2 = p2.get_main_module();
485
    {
486
        auto x     = mm2->add_parameter("x", {migraphx::shape::int32_type, {256}});
487
488
        auto slice = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {64}}, {"ends", {96}}}), x);
489
        mm2->add_return({slice});
490
491
492
493
494
495
496
    }
    EXPECT(p1 == p2);
}

// Two slices on axis 0, first [32, 128) and then [0, 32) of that result,
// should fold into a single slice [32, 64) of x.
TEST_CASE(double_slice2)
{
    migraphx::program p1;
    auto* mm1 = p1.get_main_module();
    {
        auto x      = mm1->add_parameter("x", {migraphx::shape::int32_type, {256}});
        auto slice1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {32}}, {"ends", {128}}}), x);
        auto slice2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {32}}}), slice1);
        mm1->add_return({slice2});
    }
    run_pass(p1);

    // Expected program after the pass.
    migraphx::program p2;
    auto* mm2 = p2.get_main_module();
    {
        auto x     = mm2->add_parameter("x", {migraphx::shape::int32_type, {256}});
        auto slice = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {32}}, {"ends", {64}}}), x);
        mm2->add_return({slice});
    }
    EXPECT(p1 == p2);
}

// A slice on axis 0 followed by a slice on axis 1 should be merged into one
// slice over both axes: starts {32, 0}, ends {128, 32}.
TEST_CASE(double_slice_multi_axes)
{
    migraphx::program p1;
    auto* mm1 = p1.get_main_module();
    {
        auto x      = mm1->add_parameter("x", {migraphx::shape::int32_type, {256, 128}});
        auto slice1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {32}}, {"ends", {128}}}), x);
        auto slice2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {32}}}), slice1);
        mm1->add_return({slice2});
    }
    run_pass(p1);

    // Expected program after the pass.
    migraphx::program p2;
    auto* mm2 = p2.get_main_module();
    {
        auto x     = mm2->add_parameter("x", {migraphx::shape::int32_type, {256, 128}});
        auto slice = mm2->add_instruction(
            migraphx::make_op("slice",
                              {{"axes", {0, 1}}, {"starts", {32, 0}}, {"ends", {128, 32}}}),
            x);
        mm2->add_return({slice});
    }
    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
547
int main(int argc, const char* argv[]) { test::run(argc, argv); }