simplify_algebra_test.cpp 86.6 KB
Newer Older
Paul's avatar
Paul committed
1
2
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp>
3
#include <migraphx/pass_manager.hpp>
Paul's avatar
Paul committed
4
#include <migraphx/operators.hpp>
Paul's avatar
Paul committed
5
6
7
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
Paul's avatar
Paul committed
8
9
10
#include <basic_ops.hpp>
#include <test.hpp>

11
// Runs simplify_algebra followed by dead_code_elimination over the main module of `p`,
// modifying the program in place.
void run_pass(migraphx::program& p)
{
    auto* main_module = p.get_main_module();
    migraphx::run_passes(*main_module,
                         {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}});
}
Paul's avatar
Paul committed
16

Paul's avatar
Paul committed
17
// Input:    (x + 1) + (y + 2)
// Expected: (x + y) + (1 + 2), i.e. the two literals end up grouped in their own add.
TEST_CASE(simplify_add1)
{
    const migraphx::shape scalar{migraphx::shape::int32_type, {1}};

    auto make_input = [&] {
        migraphx::program prog;
        auto* m  = prog.get_main_module();
        auto px  = m->add_parameter("x", scalar);
        auto py  = m->add_parameter("y", scalar);
        auto c1  = m->add_literal(1);
        auto c2  = m->add_literal(2);
        auto lhs = m->add_instruction(migraphx::op::add{}, px, c1);
        auto rhs = m->add_instruction(migraphx::op::add{}, py, c2);
        auto out = m->add_instruction(migraphx::op::add{}, lhs, rhs);
        m->add_instruction(pass_op{}, out);
        return prog;
    };

    auto make_expected = [&] {
        migraphx::program prog;
        auto* m     = prog.get_main_module();
        auto px     = m->add_parameter("x", scalar);
        auto py     = m->add_parameter("y", scalar);
        auto c1     = m->add_literal(1);
        auto c2     = m->add_literal(2);
        auto consts = m->add_instruction(migraphx::op::add{}, c1, c2);
        auto vars   = m->add_instruction(migraphx::op::add{}, px, py);
        auto out    = m->add_instruction(migraphx::op::add{}, vars, consts);
        m->add_instruction(pass_op{}, out);
        return prog;
    };

    migraphx::program p1 = make_input();
    run_pass(p1);
    migraphx::program p2 = make_expected();
    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
48
// Same as simplify_add1, but each literal is the FIRST operand of its inner add:
// (1 + x) + (2 + y) is expected to become (x + y) + (1 + 2).
TEST_CASE(simplify_add2)
{
    const migraphx::shape scalar{migraphx::shape::int32_type, {1}};

    migraphx::program p1;
    {
        auto* m    = p1.get_main_module();
        auto px    = m->add_parameter("x", scalar);
        auto py    = m->add_parameter("y", scalar);
        auto c1    = m->add_literal(1);
        auto c2    = m->add_literal(2);
        auto c1_x  = m->add_instruction(migraphx::op::add{}, c1, px);
        auto c2_y  = m->add_instruction(migraphx::op::add{}, c2, py);
        auto total = m->add_instruction(migraphx::op::add{}, c1_x, c2_y);
        m->add_instruction(pass_op{}, total);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* m     = p2.get_main_module();
        auto px     = m->add_parameter("x", scalar);
        auto py     = m->add_parameter("y", scalar);
        auto c1     = m->add_literal(1);
        auto c2     = m->add_literal(2);
        auto consts = m->add_instruction(migraphx::op::add{}, c1, c2);
        auto vars   = m->add_instruction(migraphx::op::add{}, px, py);
        auto total  = m->add_instruction(migraphx::op::add{}, vars, consts);
        m->add_instruction(pass_op{}, total);
    }
    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
79
// Input:    (1 + x) + (1 + 2)
// Expected: x + (1 + (1 + 2)), i.e. the literal-only subexpressions are nested together
// and the parameter is pulled out to the outermost add.
TEST_CASE(simplify_add3)
{
    const migraphx::shape scalar{migraphx::shape::int32_type, {1}};

    auto make_input = [&] {
        migraphx::program prog;
        auto* m    = prog.get_main_module();
        auto px    = m->add_parameter("x", scalar);
        auto c1    = m->add_literal(1);
        auto c2    = m->add_literal(2);
        auto c1_x  = m->add_instruction(migraphx::op::add{}, c1, px);
        auto c1_c2 = m->add_instruction(migraphx::op::add{}, c1, c2);
        auto out   = m->add_instruction(migraphx::op::add{}, c1_x, c1_c2);
        m->add_instruction(pass_op{}, out);
        return prog;
    };

    auto make_expected = [&] {
        migraphx::program prog;
        auto* m     = prog.get_main_module();
        auto px     = m->add_parameter("x", scalar);
        auto c1     = m->add_literal(1);
        auto c2     = m->add_literal(2);
        auto inner  = m->add_instruction(migraphx::op::add{}, c1, c2);
        auto consts = m->add_instruction(migraphx::op::add{}, c1, inner);
        auto out    = m->add_instruction(migraphx::op::add{}, px, consts);
        m->add_instruction(pass_op{}, out);
        return prog;
    };

    migraphx::program p1 = make_input();
    run_pass(p1);
    migraphx::program p2 = make_expected();
    EXPECT(p1 == p2);
}

108
109
110
111
112
113
114
// Input:    (x + bcast(1)) + (y + bcast(2))
// Expected: (x + y) + bcast(1 + 2), i.e. the two small literals are added before a single
// broadcast.
TEST_CASE(simplify_add_broadcast1)
{
    const migraphx::shape inner{migraphx::shape::int32_type, {2}};
    const migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
    const migraphx::op::broadcast bcast{1, {1, 2, 3, 3}};

    auto make_input = [&] {
        migraphx::program prog;
        auto* m   = prog.get_main_module();
        auto px   = m->add_parameter("x", outer);
        auto py   = m->add_parameter("y", outer);
        auto c1   = m->add_literal({inner, {1, 1}});
        auto c1b  = m->add_instruction(bcast, c1);
        auto c2   = m->add_literal({inner, {2, 2}});
        auto c2b  = m->add_instruction(bcast, c2);
        auto lhs  = m->add_instruction(migraphx::op::add{}, px, c1b);
        auto rhs  = m->add_instruction(migraphx::op::add{}, py, c2b);
        auto out  = m->add_instruction(migraphx::op::add{}, lhs, rhs);
        m->add_instruction(pass_op{}, out);
        return prog;
    };

    auto make_expected = [&] {
        migraphx::program prog;
        auto* m      = prog.get_main_module();
        auto px      = m->add_parameter("x", outer);
        auto py      = m->add_parameter("y", outer);
        auto c1      = m->add_literal({inner, {1, 1}});
        auto c2      = m->add_literal({inner, {2, 2}});
        auto consts  = m->add_instruction(migraphx::op::add{}, c1, c2);
        auto constsb = m->add_instruction(bcast, consts);
        auto vars    = m->add_instruction(migraphx::op::add{}, px, py);
        auto out     = m->add_instruction(migraphx::op::add{}, vars, constsb);
        m->add_instruction(pass_op{}, out);
        return prog;
    };

    migraphx::program p1 = make_input();
    run_pass(p1);
    migraphx::program p2 = make_expected();
    EXPECT(p1 == p2);
}

// Here only one literal is broadcast; the other is already full-size (outer shape).
// The pass is expected to leave the program exactly as built.
TEST_CASE(simplify_add_broadcast2)
{
    const migraphx::shape inner{migraphx::shape::int32_type, {2}};
    const migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
    const migraphx::op::broadcast bcast{1, {1, 2, 3, 3}};

    auto build = [&] {
        migraphx::program prog;
        auto* m    = prog.get_main_module();
        auto px    = m->add_parameter("x", outer);
        auto py    = m->add_parameter("y", outer);
        auto c1    = m->add_literal({inner, {1, 1}});
        auto c1b   = m->add_instruction(bcast, c1);
        auto full2 = m->add_literal({outer, {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}});
        auto vars  = m->add_instruction(migraphx::op::add{}, px, py);
        auto consts = m->add_instruction(migraphx::op::add{}, c1b, full2);
        auto out    = m->add_instruction(migraphx::op::add{}, consts, vars);
        m->add_instruction(pass_op{}, out);
        return prog;
    };

    migraphx::program optimized = build();
    run_pass(optimized);
    migraphx::program reference = build();
    EXPECT(optimized == reference);
}

Paul's avatar
Paul committed
171
// TODO: Add test case
172
// TEST_CASE(simplify_add4)
Paul's avatar
Paul committed
173
174
void simplify_add4()
{
Paul's avatar
Paul committed
175
    migraphx::program p1;
Paul's avatar
Paul committed
176
    {
177
178
179
180
181
182
183
184
185
        auto* mm1 = p1.get_main_module();
        auto x    = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto y    = mm1->add_parameter("y", {migraphx::shape::int32_type, {1}});
        auto one  = mm1->add_literal(1);
        auto two  = mm1->add_literal(2);
        auto sum1 = mm1->add_instruction(migraphx::op::add{}, one, x);
        auto sum2 = mm1->add_instruction(migraphx::op::add{}, sum1, y);
        auto sum3 = mm1->add_instruction(migraphx::op::add{}, sum2, two);
        mm1->add_instruction(pass_op{}, sum3);
Paul's avatar
Paul committed
186
    }
187
    run_pass(p1);
Paul's avatar
Paul committed
188

Paul's avatar
Paul committed
189
    migraphx::program p2;
Paul's avatar
Paul committed
190
    {
191
192
193
194
195
196
197
198
199
        auto* mm2 = p2.get_main_module();
        auto x    = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto y    = mm2->add_parameter("y", {migraphx::shape::int32_type, {1}});
        auto one  = mm2->add_literal(1);
        auto two  = mm2->add_literal(2);
        auto sum1 = mm2->add_instruction(migraphx::op::add{}, one, two);
        auto sum2 = mm2->add_instruction(migraphx::op::add{}, x, y);
        auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum2, sum1);
        mm2->add_instruction(pass_op{}, sum3);
Paul's avatar
Paul committed
200
201
202
203
    }
    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
204
205
206
// A convolution whose output is multiplied by a broadcast literal: before run_pass the
// convolution feeds a "mul"; after run_pass the (remaining) convolution must not.
TEST_CASE(simplify_mul_conv1)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto x   = mm->add_parameter("x", {migraphx::shape::int32_type, {1, 128, 28, 28}});
    auto w   = mm->add_literal(
        migraphx::generate_literal({migraphx::shape::int32_type, {256, 128, 3, 3}}));
    auto conv = mm->add_instruction(migraphx::op::convolution{{1, 1}, {2, 2}, {1, 1}}, x, w);
    auto a    = mm->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256}}));
    auto b    = mm->add_instruction(migraphx::op::broadcast{1, {1, 256, 14, 14}}, a);
    auto mul  = mm->add_instruction(migraphx::op::mul{}, conv, b);
    mm->add_instruction(pass_op{}, mul);
    EXPECT(conv->outputs().front()->name() == "mul");
    run_pass(p);

    auto new_conv =
        std::find_if(p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; });
    // Evaluate each precondition before the next dereference. Short-circuiting keeps the
    // test from dereferencing end() or calling front() on an empty outputs() list (both UB)
    // if the pass removed or orphaned the convolution. A regression then reports a clean
    // failure instead of crashing.
    const bool conv_found     = new_conv != p.end();
    const bool conv_has_users = conv_found and not new_conv->outputs().empty();
    const bool mul_folded =
        conv_has_users and new_conv->outputs().front()->name() != "mul";
    EXPECT(conv_found);
    EXPECT(conv_has_users);
    EXPECT(mul_folded);
}

223
224
225
226
// A mul applied to output channels [0, 384) of a convolution is moved onto the weights.
// The expected program scales weight rows [0, 384), concatenates them back with rows
// [384, 768), runs one convolution, and then slices its output as before.
TEST_CASE(simplify_mul_slice_conv1)
{
    const migraphx::shape input_shape{migraphx::shape::int32_type, {1, 1024, 17, 17}};
    const migraphx::shape weight_shape{migraphx::shape::int32_type, {768, 1024, 1, 1}};
    const migraphx::shape scale_shape{migraphx::shape::int32_type, {384}};

    migraphx::program p1;
    {
        auto* m     = p1.get_main_module();
        auto input  = m->add_parameter("x", input_shape);
        auto weight = m->add_literal(migraphx::generate_literal(weight_shape));
        auto conv   = m->add_instruction(migraphx::op::convolution{}, input, weight);
        auto lo     = m->add_instruction(migraphx::op::slice{{1}, {0}, {384}}, conv);
        auto scale  = m->add_literal(migraphx::generate_literal(scale_shape));
        auto scaleb = m->add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, scale);
        auto scaled = m->add_instruction(migraphx::op::mul{}, lo, scaleb);
        auto hi     = m->add_instruction(migraphx::op::slice{{1}, {384}, {768}}, conv);
        auto result = m->add_instruction(migraphx::op::add{}, scaled, hi);
        m->add_instruction(pass_op{}, result);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* m       = p2.get_main_module();
        auto input    = m->add_parameter("x", input_shape);
        auto weight   = m->add_literal(migraphx::generate_literal(weight_shape));
        auto w_lo     = m->add_instruction(migraphx::op::slice{{0}, {0}, {384}}, weight);
        auto scale    = m->add_literal(migraphx::generate_literal(scale_shape));
        auto scaleb   = m->add_instruction(migraphx::op::broadcast{0, {384, 1024, 1, 1}}, scale);
        auto w_scaled = m->add_instruction(migraphx::op::mul{}, scaleb, w_lo);
        auto w_hi     = m->add_instruction(migraphx::op::slice{{0}, {384}, {768}}, weight);
        auto new_w    = m->add_instruction(migraphx::op::concat{0}, w_scaled, w_hi);
        auto conv     = m->add_instruction(migraphx::op::convolution{}, input, new_w);
        auto lo       = m->add_instruction(migraphx::op::slice{{1}, {0}, {384}}, conv);
        auto hi       = m->add_instruction(migraphx::op::slice{{1}, {384}, {768}}, conv);
        auto result   = m->add_instruction(migraphx::op::add{}, lo, hi);
        m->add_instruction(pass_op{}, result);
    }
    EXPECT(p1 == p2);
}

// Same shape of graph as simplify_mul_slice_conv1, except the second slice is [383, 767),
// which overlaps the multiplied slice [0, 384). The pass must leave the program unchanged.
TEST_CASE(simplify_mul_slice_conv_overlapping_slice)
{
    const migraphx::shape input_shape{migraphx::shape::int32_type, {1, 1024, 17, 17}};
    const migraphx::shape weight_shape{migraphx::shape::int32_type, {768, 1024, 1, 1}};
    const migraphx::shape scale_shape{migraphx::shape::int32_type, {384}};

    migraphx::program prog;
    {
        auto* m     = prog.get_main_module();
        auto input  = m->add_parameter("x", input_shape);
        auto weight = m->add_literal(migraphx::generate_literal(weight_shape));
        auto conv   = m->add_instruction(migraphx::op::convolution{}, input, weight);
        auto lo     = m->add_instruction(migraphx::op::slice{{1}, {0}, {384}}, conv);
        auto scale  = m->add_literal(migraphx::generate_literal(scale_shape));
        auto scaleb = m->add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, scale);
        auto scaled = m->add_instruction(migraphx::op::mul{}, lo, scaleb);
        auto shifted = m->add_instruction(migraphx::op::slice{{1}, {383}, {767}}, conv);
        auto result  = m->add_instruction(migraphx::op::add{}, scaled, shifted);
        m->add_instruction(pass_op{}, result);
    }
    migraphx::program untouched = prog;
    run_pass(prog);
    EXPECT(prog == untouched);
}

// The convolution has two users: a channel slice (scaled by a mul) and a plain add. Because
// not every user is a slice, the pass must leave the program unchanged.
TEST_CASE(simplify_mul_slice_conv_not_all_slice)
{
    const migraphx::shape input_shape{migraphx::shape::int32_type, {1, 1024, 17, 17}};
    const migraphx::shape weight_shape{migraphx::shape::int32_type, {768, 1024, 1, 1}};
    const migraphx::shape scale_shape{migraphx::shape::int32_type, {384}};
    const migraphx::shape bias_shape{migraphx::shape::int32_type, {1, 768, 17, 17}};

    migraphx::program prog;
    {
        auto* m      = prog.get_main_module();
        auto input   = m->add_parameter("x", input_shape);
        auto weight  = m->add_literal(migraphx::generate_literal(weight_shape));
        auto conv    = m->add_instruction(migraphx::op::convolution{}, input, weight);
        auto lo      = m->add_instruction(migraphx::op::slice{{1}, {0}, {384}}, conv);
        auto scale   = m->add_literal(migraphx::generate_literal(scale_shape));
        auto scaleb  = m->add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, scale);
        auto scaled  = m->add_instruction(migraphx::op::mul{}, lo, scaleb);
        auto bias    = m->add_literal(migraphx::generate_literal(bias_shape));
        auto biased  = m->add_instruction(migraphx::op::add{}, conv, bias);
        auto joined  = m->add_instruction(migraphx::op::concat{1}, scaled, biased);
        m->add_instruction(pass_op{}, joined);
    }
    migraphx::program untouched = prog;
    run_pass(prog);
    EXPECT(prog == untouched);
}

Paul's avatar
Paul committed
309
310
311
312
// Input:    (1 + x) * 2
// Expected: 2 * x + 2 * 1, i.e. the multiplication is distributed over the add.
TEST_CASE(simplify_mul_add)
{
    const migraphx::shape scalar{migraphx::shape::int32_type, {1}};

    auto make_input = [&] {
        migraphx::program prog;
        auto* m     = prog.get_main_module();
        auto px     = m->add_parameter("x", scalar);
        auto c1     = m->add_literal(1);
        auto c2     = m->add_literal(2);
        auto summed = m->add_instruction(migraphx::op::add{}, c1, px);
        auto scaled = m->add_instruction(migraphx::op::mul{}, summed, c2);
        m->add_instruction(pass_op{}, scaled);
        return prog;
    };

    auto make_expected = [&] {
        migraphx::program prog;
        auto* m       = prog.get_main_module();
        auto px       = m->add_parameter("x", scalar);
        auto c1       = m->add_literal(1);
        auto c2       = m->add_literal(2);
        auto scaled_x = m->add_instruction(migraphx::op::mul{}, c2, px);
        auto scaled_1 = m->add_instruction(migraphx::op::mul{}, c2, c1);
        auto summed   = m->add_instruction(migraphx::op::add{}, scaled_x, scaled_1);
        m->add_instruction(pass_op{}, summed);
        return prog;
    };

    migraphx::program p1 = make_input();
    run_pass(p1);
    migraphx::program p2 = make_expected();
    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
337
338
339
340
341
// Input:    bcast(x) + bcast(y)
// Expected: bcast(x + y), i.e. the add is performed on the small inputs before broadcasting.
TEST_CASE(simplify_inner_broadcast)
{
    const migraphx::shape scalar{migraphx::shape::int32_type, {1}};
    const auto bcast = migraphx::op::broadcast{1, {2, 1, 4, 5}};

    migraphx::program p1;
    {
        auto* m   = p1.get_main_module();
        auto px   = m->add_parameter("x", scalar);
        auto py   = m->add_parameter("y", scalar);
        auto pxb  = m->add_instruction(bcast, px);
        auto pyb  = m->add_instruction(bcast, py);
        auto both = m->add_instruction(migraphx::op::add{}, pxb, pyb);
        m->add_instruction(pass_op{}, both);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* m    = p2.get_main_module();
        auto px    = m->add_parameter("x", scalar);
        auto py    = m->add_parameter("y", scalar);
        auto both  = m->add_instruction(migraphx::op::add{}, px, py);
        auto bothb = m->add_instruction(bcast, both);
        m->add_instruction(pass_op{}, bothb);
    }
    EXPECT(p1 == p2);
}

364
365
366
// Two convolutions with identical input/weight shapes and default attributes are summed.
// After the pass the output shape must be unchanged and exactly one convolution must remain.
TEST_CASE(simplify_add_conv1)
{
    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 128, 28, 28}};
    const migraphx::shape weight_shape{migraphx::shape::float_type, {256, 128, 3, 3}};

    migraphx::program p;
    auto* mm   = p.get_main_module();
    auto x     = mm->add_parameter("x", input_shape);
    auto w     = mm->add_literal(migraphx::generate_literal(weight_shape));
    auto y     = mm->add_parameter("y", input_shape);
    auto v     = mm->add_literal(migraphx::generate_literal(weight_shape));
    auto conv1 = mm->add_instruction(migraphx::op::convolution{}, x, w);
    auto conv2 = mm->add_instruction(migraphx::op::convolution{}, y, v);
    auto total = mm->add_instruction(migraphx::op::add{}, conv1, conv2);
    mm->add_instruction(pass_op{}, total);

    auto count_convolutions = [&] {
        return std::count_if(
            p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; });
    };

    auto shape_before = p.get_output_shapes().back();
    run_pass(p);
    EXPECT(shape_before == p.get_output_shapes().back());
    EXPECT(count_convolutions() == 1);
}

// 7x7 convolutions with different strides (1 vs 3) on different input sizes: the pass must
// keep the output shape and must NOT merge them (two convolutions remain).
TEST_CASE(simplify_add_conv_no_fusion_7x7_diff_strides)
{
    const migraphx::shape small_input{migraphx::shape::float_type, {1, 128, 14, 14}};
    const migraphx::shape large_input{migraphx::shape::float_type, {1, 128, 28, 28}};
    const migraphx::shape weight_shape{migraphx::shape::float_type, {256, 128, 7, 7}};

    migraphx::program p;
    auto* mm   = p.get_main_module();
    auto x     = mm->add_parameter("x", small_input);
    auto w     = mm->add_literal(migraphx::generate_literal(weight_shape));
    auto y     = mm->add_parameter("y", large_input);
    auto v     = mm->add_literal(migraphx::generate_literal(weight_shape));
    auto conv1 = mm->add_instruction(migraphx::op::convolution{}, x, w);
    auto conv2 = mm->add_instruction(migraphx::op::convolution{{0, 0}, {3, 3}}, y, v);
    auto total = mm->add_instruction(migraphx::op::add{}, conv1, conv2);
    mm->add_instruction(pass_op{}, total);

    auto count_convolutions = [&] {
        return std::count_if(
            p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; });
    };

    auto shape_before = p.get_output_shapes().back();
    run_pass(p);
    EXPECT(shape_before == p.get_output_shapes().back());
    // No fusion
    EXPECT(count_convolutions() == 2);
}

// 1x1 convolutions: stride 1 on a 14x14 input plus stride 2 on a 28x28 input. The pass is
// expected to merge them into one convolution while keeping the output shape.
TEST_CASE(simplify_add_conv_1x1_diff_strides1)
{
    const migraphx::shape small_input{migraphx::shape::float_type, {1, 128, 14, 14}};
    const migraphx::shape large_input{migraphx::shape::float_type, {1, 128, 28, 28}};
    const migraphx::shape weight_shape{migraphx::shape::float_type, {256, 128, 1, 1}};

    migraphx::program p;
    auto* mm   = p.get_main_module();
    auto x     = mm->add_parameter("x", small_input);
    auto w     = mm->add_literal(migraphx::generate_literal(weight_shape));
    auto y     = mm->add_parameter("y", large_input);
    auto v     = mm->add_literal(migraphx::generate_literal(weight_shape));
    auto conv1 = mm->add_instruction(migraphx::op::convolution{}, x, w);
    auto conv2 = mm->add_instruction(migraphx::op::convolution{{0, 0}, {2, 2}}, y, v);
    auto total = mm->add_instruction(migraphx::op::add{}, conv1, conv2);
    mm->add_instruction(pass_op{}, total);

    auto count_convolutions = [&] {
        return std::count_if(
            p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; });
    };

    auto shape_before = p.get_output_shapes().back();
    run_pass(p);
    EXPECT(shape_before == p.get_output_shapes().back());
    EXPECT(count_convolutions() == 1);
}

// Mirror of simplify_add_conv_1x1_diff_strides1: here the FIRST convolution is the strided
// one (stride 2 on 28x28) and the second is stride 1 on 14x14. Expect one convolution to
// remain and the output shape to be unchanged.
TEST_CASE(simplify_add_conv_1x1_diff_strides2)
{
    const migraphx::shape large_input{migraphx::shape::float_type, {1, 128, 28, 28}};
    const migraphx::shape small_input{migraphx::shape::float_type, {1, 128, 14, 14}};
    const migraphx::shape weight_shape{migraphx::shape::float_type, {256, 128, 1, 1}};

    migraphx::program p;
    auto* mm   = p.get_main_module();
    auto x     = mm->add_parameter("x", large_input);
    auto w     = mm->add_literal(migraphx::generate_literal(weight_shape));
    auto y     = mm->add_parameter("y", small_input);
    auto v     = mm->add_literal(migraphx::generate_literal(weight_shape));
    auto conv1 = mm->add_instruction(migraphx::op::convolution{{0, 0}, {2, 2}}, x, w);
    auto conv2 = mm->add_instruction(migraphx::op::convolution{}, y, v);
    auto total = mm->add_instruction(migraphx::op::add{}, conv1, conv2);
    mm->add_instruction(pass_op{}, total);

    auto count_convolutions = [&] {
        return std::count_if(
            p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; });
    };

    auto shape_before = p.get_output_shapes().back();
    run_pass(p);
    EXPECT(shape_before == p.get_output_shapes().back());
    EXPECT(count_convolutions() == 1);
}

// 1x1 convolutions on odd spatial sizes: stride 1 on 83x83 plus stride 2 on 165x165. Expect
// them to be merged into one convolution with the output shape unchanged.
TEST_CASE(simplify_add_conv_1x1_diff_strides_odd)
{
    const migraphx::shape small_input{migraphx::shape::float_type, {1, 54, 83, 83}};
    const migraphx::shape large_input{migraphx::shape::float_type, {1, 54, 165, 165}};
    const migraphx::shape weight_shape{migraphx::shape::float_type, {54, 54, 1, 1}};

    migraphx::program p;
    auto* mm   = p.get_main_module();
    auto x     = mm->add_parameter("x", small_input);
    auto w     = mm->add_literal(migraphx::generate_literal(weight_shape));
    auto y     = mm->add_parameter("y", large_input);
    auto v     = mm->add_literal(migraphx::generate_literal(weight_shape));
    auto conv1 = mm->add_instruction(migraphx::op::convolution{}, x, w);
    auto conv2 = mm->add_instruction(migraphx::op::convolution{{0, 0}, {2, 2}}, y, v);
    auto total = mm->add_instruction(migraphx::op::add{}, conv1, conv2);
    mm->add_instruction(pass_op{}, total);

    auto count_convolutions = [&] {
        return std::count_if(
            p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; });
    };

    auto shape_before = p.get_output_shapes().back();
    run_pass(p);
    EXPECT(shape_before == p.get_output_shapes().back());
    EXPECT(count_convolutions() == 1);
}

// The first convolution uses an asymmetric stride {2, 1} on a 28x14 input; the second uses
// the default stride on 14x14. The pass must keep the output shape and must NOT merge them.
TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides1)
{
    const migraphx::shape tall_input{migraphx::shape::float_type, {1, 128, 28, 14}};
    const migraphx::shape square_input{migraphx::shape::float_type, {1, 128, 14, 14}};
    const migraphx::shape weight_shape{migraphx::shape::float_type, {256, 128, 1, 1}};

    migraphx::program p;
    auto* mm   = p.get_main_module();
    auto x     = mm->add_parameter("x", tall_input);
    auto w     = mm->add_literal(migraphx::generate_literal(weight_shape));
    auto y     = mm->add_parameter("y", square_input);
    auto v     = mm->add_literal(migraphx::generate_literal(weight_shape));
    auto conv1 = mm->add_instruction(migraphx::op::convolution{{0, 0}, {2, 1}}, x, w);
    auto conv2 = mm->add_instruction(migraphx::op::convolution{}, y, v);
    auto total = mm->add_instruction(migraphx::op::add{}, conv1, conv2);
    mm->add_instruction(pass_op{}, total);

    auto count_convolutions = [&] {
        return std::count_if(
            p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; });
    };

    auto shape_before = p.get_output_shapes().back();
    run_pass(p);
    EXPECT(shape_before == p.get_output_shapes().back());
    // No fusion
    EXPECT(count_convolutions() == 2);
}

// Mirror of ..._asymetrical_strides1: the SECOND convolution has the asymmetric {2, 1}
// stride (on 28x14). The pass must keep the output shape and must NOT merge them.
TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides2)
{
    const migraphx::shape square_input{migraphx::shape::float_type, {1, 128, 14, 14}};
    const migraphx::shape tall_input{migraphx::shape::float_type, {1, 128, 28, 14}};
    const migraphx::shape weight_shape{migraphx::shape::float_type, {256, 128, 1, 1}};

    migraphx::program p;
    auto* mm   = p.get_main_module();
    auto x     = mm->add_parameter("x", square_input);
    auto w     = mm->add_literal(migraphx::generate_literal(weight_shape));
    auto y     = mm->add_parameter("y", tall_input);
    auto v     = mm->add_literal(migraphx::generate_literal(weight_shape));
    auto conv1 = mm->add_instruction(migraphx::op::convolution{}, x, w);
    auto conv2 = mm->add_instruction(migraphx::op::convolution{{0, 0}, {2, 1}}, y, v);
    auto total = mm->add_instruction(migraphx::op::add{}, conv1, conv2);
    mm->add_instruction(pass_op{}, total);

    auto count_convolutions = [&] {
        return std::count_if(
            p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; });
    };

    auto shape_before = p.get_output_shapes().back();
    run_pass(p);
    EXPECT(shape_before == p.get_output_shapes().back());
    // No fusion
    EXPECT(count_convolutions() == 2);
}

514
515
516
517
518
// Input:    concat(relu(x + 1), relu(y + 2))
// Expected: relu(concat(x, y) + concat(1, 2)), i.e. one add and one relu over concatenated
// operands.
TEST_CASE(simplify_concat_add_relu)
{
    const auto s = migraphx::shape{migraphx::shape::int32_type, {1}};

    auto make_input = [&] {
        migraphx::program prog;
        auto* m    = prog.get_main_module();
        auto px    = m->add_parameter("x", s);
        auto py    = m->add_parameter("y", s);
        auto c1    = m->add_literal({s, {1}});
        auto c2    = m->add_literal({s, {2}});
        auto xsum  = m->add_instruction(migraphx::op::add{}, px, c1);
        auto xrelu = m->add_instruction(migraphx::op::relu{}, xsum);
        auto ysum  = m->add_instruction(migraphx::op::add{}, py, c2);
        auto yrelu = m->add_instruction(migraphx::op::relu{}, ysum);
        auto cat   = m->add_instruction(migraphx::op::concat{0}, xrelu, yrelu);
        m->add_instruction(pass_op{}, cat);
        return prog;
    };

    auto make_expected = [&] {
        migraphx::program prog;
        auto* m        = prog.get_main_module();
        auto px        = m->add_parameter("x", s);
        auto py        = m->add_parameter("y", s);
        auto c1        = m->add_literal({s, {1}});
        auto c2        = m->add_literal({s, {2}});
        auto vars_cat  = m->add_instruction(migraphx::op::concat{0}, px, py);
        auto const_cat = m->add_instruction(migraphx::op::concat{0}, c1, c2);
        auto summed    = m->add_instruction(migraphx::op::add{}, vars_cat, const_cat);
        auto activated = m->add_instruction(migraphx::op::relu{}, summed);
        m->add_instruction(pass_op{}, activated);
        return prog;
    };

    migraphx::program p1 = make_input();
    run_pass(p1);
    migraphx::program p2 = make_expected();
    EXPECT(p1 == p2);
}

549
550
551
552
553
// Only the last two concat inputs share the add+relu pattern; the first input (x + y) does
// not. Expected: concat(x + y, relu(concat(x, y) + concat(1, 2))). Both programs are sorted
// before comparison, so instruction order is not part of the check.
TEST_CASE(simplify_concat_add_relu_partial)
{
    const auto s = migraphx::shape{migraphx::shape::int32_type, {1}};

    migraphx::program p1;
    {
        auto* m     = p1.get_main_module();
        auto px     = m->add_parameter("x", s);
        auto py     = m->add_parameter("y", s);
        auto c1     = m->add_literal({s, {1}});
        auto c2     = m->add_literal({s, {2}});
        auto xsum   = m->add_instruction(migraphx::op::add{}, px, c1);
        auto xrelu  = m->add_instruction(migraphx::op::relu{}, xsum);
        auto ysum   = m->add_instruction(migraphx::op::add{}, py, c2);
        auto yrelu  = m->add_instruction(migraphx::op::relu{}, ysum);
        auto xy_sum = m->add_instruction(migraphx::op::add{}, px, py);
        auto cat    = m->add_instruction(migraphx::op::concat{0}, xy_sum, xrelu, yrelu);
        m->add_instruction(pass_op{}, cat);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* m        = p2.get_main_module();
        auto px        = m->add_parameter("x", s);
        auto py        = m->add_parameter("y", s);
        auto c1        = m->add_literal({s, {1}});
        auto c2        = m->add_literal({s, {2}});
        auto vars_cat  = m->add_instruction(migraphx::op::concat{0}, px, py);
        auto const_cat = m->add_instruction(migraphx::op::concat{0}, c1, c2);
        auto summed    = m->add_instruction(migraphx::op::add{}, vars_cat, const_cat);
        auto activated = m->add_instruction(migraphx::op::relu{}, summed);
        auto xy_sum    = m->add_instruction(migraphx::op::add{}, px, py);
        auto cat       = m->add_instruction(migraphx::op::concat{0}, xy_sum, activated);
        m->add_instruction(pass_op{}, cat);
    }
    EXPECT(p1.sort() == p2.sort());
}

// Input:    concat_axis1(x + y, bcast(1), bcast(2))
// Expected: concat_axis1(x + y, bcast'(concat_axis0(1, 2))), i.e. the two broadcast
// literals are concatenated first and broadcast once to the combined {2, 2, 4, 5} shape.
// Both programs are sorted before comparison.
TEST_CASE(simplify_concat_add_relu_partial_broadcast)
{
    const auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};

    auto make_input = [&] {
        migraphx::program prog;
        auto* m     = prog.get_main_module();
        auto bcast  = migraphx::op::broadcast{1, {2, 1, 4, 5}};
        auto px     = m->add_parameter("x", s);
        auto py     = m->add_parameter("y", s);
        auto c1     = m->add_literal(1);
        auto c1b    = m->add_instruction(bcast, c1);
        auto c2     = m->add_literal(2);
        auto c2b    = m->add_instruction(bcast, c2);
        auto xy_sum = m->add_instruction(migraphx::op::add{}, px, py);
        auto cat    = m->add_instruction(migraphx::op::concat{1}, xy_sum, c1b, c2b);
        m->add_instruction(pass_op{}, cat);
        return prog;
    };

    auto make_expected = [&] {
        migraphx::program prog;
        auto* m          = prog.get_main_module();
        auto wide_bcast  = migraphx::op::broadcast{1, {2, 2, 4, 5}};
        auto px          = m->add_parameter("x", s);
        auto py          = m->add_parameter("y", s);
        auto c1          = m->add_literal(1);
        auto c2          = m->add_literal(2);
        auto const_cat   = m->add_instruction(migraphx::op::concat{0}, c1, c2);
        auto const_catb  = m->add_instruction(wide_bcast, const_cat);
        auto xy_sum      = m->add_instruction(migraphx::op::add{}, px, py);
        auto cat         = m->add_instruction(migraphx::op::concat{1}, xy_sum, const_catb);
        m->add_instruction(pass_op{}, cat);
        return prog;
    };

    migraphx::program p1 = make_input();
    run_pass(p1);
    migraphx::program p2 = make_expected();
    EXPECT(p1.sort() == p2.sort());
}

623
624
625
626
627
// Input:    concat_axis1(relu(x + bcast(1)), relu(y + bcast(2)))
// Expected: relu(concat_axis1(x, y) + bcast'(concat_axis0(1, 2))), where bcast' broadcasts
// the concatenated literals to the combined {2, 2, 4, 5} shape.
TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
{
    const auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};

    migraphx::program p1;
    {
        auto* m    = p1.get_main_module();
        auto bcast = migraphx::op::broadcast{1, {2, 1, 4, 5}};
        auto px    = m->add_parameter("x", s);
        auto py    = m->add_parameter("y", s);
        auto c1    = m->add_literal(1);
        auto c1b   = m->add_instruction(bcast, c1);
        auto c2    = m->add_literal(2);
        auto c2b   = m->add_instruction(bcast, c2);
        auto xsum  = m->add_instruction(migraphx::op::add{}, px, c1b);
        auto xrelu = m->add_instruction(migraphx::op::relu{}, xsum);
        auto ysum  = m->add_instruction(migraphx::op::add{}, py, c2b);
        auto yrelu = m->add_instruction(migraphx::op::relu{}, ysum);
        auto cat   = m->add_instruction(migraphx::op::concat{1}, xrelu, yrelu);
        m->add_instruction(pass_op{}, cat);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* m         = p2.get_main_module();
        auto wide_bcast = migraphx::op::broadcast{1, {2, 2, 4, 5}};
        auto px         = m->add_parameter("x", s);
        auto py         = m->add_parameter("y", s);
        auto c1         = m->add_literal(1);
        auto c2         = m->add_literal(2);
        auto vars_cat   = m->add_instruction(migraphx::op::concat{1}, px, py);
        auto const_cat  = m->add_instruction(migraphx::op::concat{0}, c1, c2);
        auto const_catb = m->add_instruction(wide_bcast, const_cat);
        auto summed     = m->add_instruction(migraphx::op::add{}, vars_cat, const_catb);
        auto activated  = m->add_instruction(migraphx::op::relu{}, summed);
        m->add_instruction(pass_op{}, activated);
    }
    EXPECT(p1 == p2);
}

// Two add+relu branches, each adding a broadcast scalar literal, are
// concatenated on axis 0. The expected result keeps the per-branch broadcasts,
// concatenates them on axis 0, and applies a single add+relu to
// concat{0}(x, y).
TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
    migraphx::program p1;
    {
        auto* mm1   = p1.get_main_module();
        auto b      = migraphx::op::broadcast{1, {2, 1, 4, 5}};
        auto x      = mm1->add_parameter("x", s);
        auto y      = mm1->add_parameter("y", s);
        auto one    = mm1->add_literal(1);
        auto oneb   = mm1->add_instruction(b, one);
        auto two    = mm1->add_literal(2);
        auto twob   = mm1->add_instruction(b, two);
        auto sum1   = mm1->add_instruction(migraphx::op::add{}, x, oneb);
        auto relu1  = mm1->add_instruction(migraphx::op::relu{}, sum1);
        auto sum2   = mm1->add_instruction(migraphx::op::add{}, y, twob);
        auto relu2  = mm1->add_instruction(migraphx::op::relu{}, sum2);
        auto concat = mm1->add_instruction(migraphx::op::concat{0}, relu1, relu2);
        mm1->add_instruction(pass_op{}, concat);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2    = p2.get_main_module();
        auto b       = migraphx::op::broadcast{1, {2, 1, 4, 5}};
        auto x       = mm2->add_parameter("x", s);
        auto y       = mm2->add_parameter("y", s);
        auto one     = mm2->add_literal(1);
        auto oneb    = mm2->add_instruction(b, one);
        auto two     = mm2->add_literal(2);
        auto twob    = mm2->add_instruction(b, two);
        auto concat1 = mm2->add_instruction(migraphx::op::concat{0}, x, y);
        auto concat2 = mm2->add_instruction(migraphx::op::concat{0}, oneb, twob);
        auto sum     = mm2->add_instruction(migraphx::op::add{}, concat1, concat2);
        auto relu    = mm2->add_instruction(migraphx::op::relu{}, sum);
        mm2->add_instruction(pass_op{}, relu);
    }
    EXPECT(p1 == p2);
}

// Dividing by a literal is expected to become multiplication by recip(literal).
// In the expected program the recip is inserted immediately after the literal.
TEST_CASE(simplify_div_const)
{
    migraphx::program p1;
    {
        auto* mm1 = p1.get_main_module();
        auto x    = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto two  = mm1->add_literal(2);
        mm1->add_instruction(migraphx::op::div{}, x, two);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2  = p2.get_main_module();
        auto x     = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto two   = mm2->add_literal(2);
        auto recip = mm2->insert_instruction(std::next(two), migraphx::op::recip{}, two);
        mm2->add_instruction(migraphx::op::mul{}, x, recip);
    }
    EXPECT(p1 == p2);
}

// Subtracting a literal is expected to become addition of neg(literal). In the
// expected program the neg is inserted immediately after the literal.
TEST_CASE(simplify_sub_const)
{
    migraphx::program p1;
    {
        auto* mm1 = p1.get_main_module();
        auto x    = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto two  = mm1->add_literal(2);
        mm1->add_instruction(migraphx::op::sub{}, x, two);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2 = p2.get_main_module();
        auto x    = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto two  = mm2->add_literal(2);
        auto neg  = mm2->insert_instruction(std::next(two), migraphx::op::neg{}, two);
        mm2->add_instruction(migraphx::op::add{}, x, neg);
    }
    EXPECT(p1 == p2);
}

// recip(sqrt(x)) is expected to be folded into a single rsqrt(x).
TEST_CASE(simplify_rsqrt)
{
    migraphx::program p1;
    {
        auto* mm1 = p1.get_main_module();
        auto x    = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto sqrt = mm1->add_instruction(migraphx::op::sqrt{}, x);
        mm1->add_instruction(migraphx::op::recip{}, sqrt);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2 = p2.get_main_module();
        auto x    = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
        mm2->add_instruction(migraphx::op::rsqrt{}, x);
    }
    EXPECT(p1 == p2);
}

// When the sqrt also feeds another instruction (the add), recip(sqrt(x)) must
// not be rewritten, so the program is expected to be unchanged by the pass.
TEST_CASE(simplify_rsqrt_multi_use)
{
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto x     = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto sqrt  = mm1->add_instruction(migraphx::op::sqrt{}, x);
        auto add   = mm1->add_instruction(migraphx::op::add{}, sqrt, sqrt);
        auto recip = mm1->add_instruction(migraphx::op::recip{}, sqrt);
        mm1->add_instruction(migraphx::op::add{}, recip, add);
    }
    // Snapshot before the pass; the pass must leave p1 identical to it.
    migraphx::program p2{p1};

    run_pass(p1);
    EXPECT(p1 == p2);
}

// Concatenating in-order slices that exactly cover x and then y is expected to
// collapse into concat{0}(x, y).
TEST_CASE(simplify_slice_concat)
{
    auto s = migraphx::shape{migraphx::shape::float_type, {256}};

    migraphx::program p1;
    {
        auto* mm1    = p1.get_main_module();
        auto x       = mm1->add_parameter("x", s);
        auto y       = mm1->add_parameter("y", s);
        auto xslice1 = mm1->add_instruction(migraphx::op::slice{{0}, {0}, {128}}, x);
        auto xslice2 = mm1->add_instruction(migraphx::op::slice{{0}, {128}, {256}}, x);
        auto yslice1 = mm1->add_instruction(migraphx::op::slice{{0}, {0}, {128}}, y);
        auto yslice2 = mm1->add_instruction(migraphx::op::slice{{0}, {128}, {256}}, y);
        auto concat =
            mm1->add_instruction(migraphx::op::concat{0}, xslice1, xslice2, yslice1, yslice2);
        mm1->add_instruction(pass_op{}, concat);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2   = p2.get_main_module();
        auto x      = mm2->add_parameter("x", s);
        auto y      = mm2->add_parameter("y", s);
        auto concat = mm2->add_instruction(migraphx::op::concat{0}, x, y);
        mm2->add_instruction(pass_op{}, concat);
    }
    EXPECT(p1 == p2);
}

// Same shape of rewrite as simplify_slice_concat, but the slices have unequal
// widths (64, 128, 64). In-order coverage is still expected to collapse into
// concat{0}(x, y).
TEST_CASE(simplify_slice_concat_non_uniform)
{
    auto s = migraphx::shape{migraphx::shape::float_type, {256}};

    migraphx::program p1;
    {
        auto* mm1    = p1.get_main_module();
        auto x       = mm1->add_parameter("x", s);
        auto y       = mm1->add_parameter("y", s);
        auto xslice1 = mm1->add_instruction(migraphx::op::slice{{0}, {0}, {64}}, x);
        auto xslice2 = mm1->add_instruction(migraphx::op::slice{{0}, {64}, {192}}, x);
        auto xslice3 = mm1->add_instruction(migraphx::op::slice{{0}, {192}, {256}}, x);
        auto yslice1 = mm1->add_instruction(migraphx::op::slice{{0}, {0}, {64}}, y);
        auto yslice2 = mm1->add_instruction(migraphx::op::slice{{0}, {64}, {192}}, y);
        auto yslice3 = mm1->add_instruction(migraphx::op::slice{{0}, {192}, {256}}, y);
        auto concat  = mm1->add_instruction(
            migraphx::op::concat{0}, xslice1, xslice2, xslice3, yslice1, yslice2, yslice3);
        mm1->add_instruction(pass_op{}, concat);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2   = p2.get_main_module();
        auto x      = mm2->add_parameter("x", s);
        auto y      = mm2->add_parameter("y", s);
        auto concat = mm2->add_instruction(migraphx::op::concat{0}, x, y);
        mm2->add_instruction(pass_op{}, concat);
    }

    EXPECT(p1 == p2);
}

// The slices are concatenated out of order ([0,64), [192,256), [64,192)), so
// they do not reassemble the inputs and the program is expected to be
// unchanged.
TEST_CASE(simplify_slice_concat_flipped)
{
    auto s = migraphx::shape{migraphx::shape::float_type, {256}};

    migraphx::program p1;
    {
        auto* mm1    = p1.get_main_module();
        auto x       = mm1->add_parameter("x", s);
        auto y       = mm1->add_parameter("y", s);
        auto xslice1 = mm1->add_instruction(migraphx::op::slice{{0}, {0}, {64}}, x);
        auto xslice2 = mm1->add_instruction(migraphx::op::slice{{0}, {192}, {256}}, x);
        auto xslice3 = mm1->add_instruction(migraphx::op::slice{{0}, {64}, {192}}, x);
        auto yslice1 = mm1->add_instruction(migraphx::op::slice{{0}, {0}, {64}}, y);
        auto yslice2 = mm1->add_instruction(migraphx::op::slice{{0}, {192}, {256}}, y);
        auto yslice3 = mm1->add_instruction(migraphx::op::slice{{0}, {64}, {192}}, y);
        auto concat  = mm1->add_instruction(
            migraphx::op::concat{0}, xslice1, xslice2, xslice3, yslice1, yslice2, yslice3);
        mm1->add_instruction(pass_op{}, concat);
    }
    // Snapshot before the pass; the pass must leave p1 identical to it.
    migraphx::program p2 = p1;
    run_pass(p1);

    EXPECT(p1 == p2);
}

// Two slices of one input feed identical add(broadcast scalar)+relu chains.
// Expected: a single add+relu on the whole input against a broadcast of
// concat{0}(1, 2), with the original slices taken from the relu result.
TEST_CASE(simplify_split_add_relu)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::op::broadcast{1, {3, 1, 4}};
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
        auto y     = mm1->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::op::add{}, x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
        auto sum2  = mm1->add_instruction(migraphx::op::add{}, y, twob);
        auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
        auto add   = mm1->add_instruction(migraphx::op::add{}, relu1, relu2);
        mm1->add_instruction(pass_op{}, add);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2    = p2.get_main_module();
        auto b       = migraphx::op::broadcast{1, {3, 2, 4}};
        auto input   = mm2->add_parameter("input", s);
        auto one     = mm2->add_literal(1);
        auto two     = mm2->add_literal(2);
        auto concat  = mm2->add_instruction(migraphx::op::concat{0}, one, two);
        auto concatb = mm2->add_instruction(b, concat);
        auto sum     = mm2->add_instruction(migraphx::op::add{}, input, concatb);
        auto relu    = mm2->add_instruction(migraphx::op::relu{}, sum);
        auto x       = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, relu);
        auto y       = mm2->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, relu);
        auto add     = mm2->add_instruction(migraphx::op::add{}, x, y);
        mm2->add_instruction(pass_op{}, add);
    }
    EXPECT(p1.sort() == p2.sort());
}

// Like simplify_split_add_relu, but each branch is reshaped to {3, 4} after
// the relu. Expected: the merged add+relu result is reshaped once to {3, 8}
// and then sliced along axis 1 into [0,4) and [4,8).
TEST_CASE(simplify_split_add_relu_reshape)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
    migraphx::program p1;
    {
        auto* mm1     = p1.get_main_module();
        auto b        = migraphx::op::broadcast{1, {3, 1, 4}};
        auto r        = migraphx::op::reshape{{3, 4}};
        auto input    = mm1->add_parameter("input", s);
        auto x        = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
        auto y        = mm1->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
        auto one      = mm1->add_literal(1);
        auto oneb     = mm1->add_instruction(b, one);
        auto two      = mm1->add_literal(2);
        auto twob     = mm1->add_instruction(b, two);
        auto sum1     = mm1->add_instruction(migraphx::op::add{}, x, oneb);
        auto relu1    = mm1->add_instruction(migraphx::op::relu{}, sum1);
        auto reshape1 = mm1->add_instruction(r, relu1);
        auto sum2     = mm1->add_instruction(migraphx::op::add{}, y, twob);
        auto relu2    = mm1->add_instruction(migraphx::op::relu{}, sum2);
        auto reshape2 = mm1->add_instruction(r, relu2);
        auto add      = mm1->add_instruction(migraphx::op::add{}, reshape1, reshape2);
        mm1->add_instruction(pass_op{}, add);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2    = p2.get_main_module();
        auto b       = migraphx::op::broadcast{1, {3, 2, 4}};
        auto input   = mm2->add_parameter("input", s);
        auto one     = mm2->add_literal(1);
        auto two     = mm2->add_literal(2);
        auto concat  = mm2->add_instruction(migraphx::op::concat{0}, one, two);
        auto concatb = mm2->add_instruction(b, concat);
        auto sum     = mm2->add_instruction(migraphx::op::add{}, input, concatb);
        auto relu    = mm2->add_instruction(migraphx::op::relu{}, sum);
        auto rsp     = mm2->add_instruction(migraphx::op::reshape{{3, 8}}, relu);
        auto slc1    = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {4}}, rsp);
        auto slc2    = mm2->add_instruction(migraphx::op::slice{{1}, {4}, {8}}, rsp);
        auto add     = mm2->add_instruction(migraphx::op::add{}, slc1, slc2);
        mm2->add_instruction(pass_op{}, add);
    }
    EXPECT(p1.sort() == p2.sort());
}

// The two slices are taken along different axes (1 and 3), so the branches
// are not merged and the program is expected to be unchanged.
TEST_CASE(simplify_slice_different_axis)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 2}};
    migraphx::program p1;
    {
        auto* mm1     = p1.get_main_module();
        auto r        = migraphx::op::reshape{{3, 2, 4}};
        auto input    = mm1->add_parameter("input", s);
        auto x        = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
        auto y        = mm1->add_instruction(migraphx::op::slice{{3}, {0}, {1}}, input);
        auto one      = mm1->add_literal(1);
        auto oneb     = mm1->add_instruction(migraphx::op::broadcast{1, {3, 1, 4, 2}}, one);
        auto two      = mm1->add_literal(2);
        auto twob     = mm1->add_instruction(migraphx::op::broadcast{3, {3, 2, 4, 1}}, two);
        auto sum1     = mm1->add_instruction(migraphx::op::add{}, x, oneb);
        auto relu1    = mm1->add_instruction(migraphx::op::relu{}, sum1);
        auto reshape1 = mm1->add_instruction(r, relu1);
        auto sum2     = mm1->add_instruction(migraphx::op::add{}, y, twob);
        auto relu2    = mm1->add_instruction(migraphx::op::relu{}, sum2);
        auto reshape2 = mm1->add_instruction(r, relu2);
        auto add      = mm1->add_instruction(migraphx::op::add{}, reshape1, reshape2);
        mm1->add_instruction(pass_op{}, add);
    }
    // Snapshot before the pass; the pass must leave p1 identical to it.
    migraphx::program p2 = p1;
    run_pass(p1);

    EXPECT(p1.sort() == p2.sort());
}

// Axis 1 has size 3, but only slices [2,3) and [1,2) are used; the leading
// [0,1) slice is missing. The program is expected to be unchanged.
TEST_CASE(simplify_slice_missing_begining_slice)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::op::broadcast{1, {3, 1, 4}};
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(migraphx::op::slice{{1}, {2}, {3}}, input);
        auto y     = mm1->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::op::add{}, x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
        auto sum2  = mm1->add_instruction(migraphx::op::add{}, y, twob);
        auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
        auto add   = mm1->add_instruction(migraphx::op::add{}, relu1, relu2);
        mm1->add_instruction(pass_op{}, add);
    }
    // Snapshot before the pass; the pass must leave p1 identical to it.
    migraphx::program p2 = p1;
    run_pass(p1);

    EXPECT(p1.sort() == p2.sort());
}

// Axis 1 has size 3, but only slices [2,3) and [0,1) are used; the middle
// [1,2) slice is missing. The program is expected to be unchanged.
TEST_CASE(simplify_slice_missing_middle_slice)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::op::broadcast{1, {3, 1, 4}};
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(migraphx::op::slice{{1}, {2}, {3}}, input);
        auto y     = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::op::add{}, x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
        auto sum2  = mm1->add_instruction(migraphx::op::add{}, y, twob);
        auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
        auto add   = mm1->add_instruction(migraphx::op::add{}, relu1, relu2);
        mm1->add_instruction(pass_op{}, add);
    }
    // Snapshot before the pass; the pass must leave p1 identical to it.
    migraphx::program p2 = p1;
    run_pass(p1);

    EXPECT(p1.sort() == p2.sort());
}

// Axis 1 has size 3, but only slices [0,1) and [1,2) are used; the trailing
// [2,3) slice is missing. The program is expected to be unchanged.
TEST_CASE(simplify_slice_missing_end_slice)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::op::broadcast{1, {3, 1, 4}};
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
        auto y     = mm1->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::op::add{}, x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
        auto sum2  = mm1->add_instruction(migraphx::op::add{}, y, twob);
        auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
        auto add   = mm1->add_instruction(migraphx::op::add{}, relu1, relu2);
        mm1->add_instruction(pass_op{}, add);
    }
    // Snapshot before the pass; the pass must leave p1 identical to it.
    migraphx::program p2 = p1;
    run_pass(p1);

    EXPECT(p1.sort() == p2.sort());
}

// The input is split on axis 1, each piece goes through add(broadcast
// scalar)+relu, and the results are concatenated back on axis 1. Expected: a
// single add+relu on the whole input, with no remaining slices or concat of
// the results.
TEST_CASE(simplify_split_add_relu_concat_same_axis)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
    migraphx::program p1;
    {
        auto* mm1   = p1.get_main_module();
        auto b      = migraphx::op::broadcast{1, {3, 1, 4}};
        auto input  = mm1->add_parameter("input", s);
        auto x      = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
        auto y      = mm1->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
        auto one    = mm1->add_literal(1);
        auto oneb   = mm1->add_instruction(b, one);
        auto two    = mm1->add_literal(2);
        auto twob   = mm1->add_instruction(b, two);
        auto sum1   = mm1->add_instruction(migraphx::op::add{}, x, oneb);
        auto relu1  = mm1->add_instruction(migraphx::op::relu{}, sum1);
        auto sum2   = mm1->add_instruction(migraphx::op::add{}, y, twob);
        auto relu2  = mm1->add_instruction(migraphx::op::relu{}, sum2);
        auto concat = mm1->add_instruction(migraphx::op::concat{1}, relu1, relu2);
        mm1->add_instruction(pass_op{}, concat);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2    = p2.get_main_module();
        auto b       = migraphx::op::broadcast{1, {3, 2, 4}};
        auto input   = mm2->add_parameter("input", s);
        auto one     = mm2->add_literal(1);
        auto two     = mm2->add_literal(2);
        auto concat  = mm2->add_instruction(migraphx::op::concat{0}, one, two);
        auto concatb = mm2->add_instruction(b, concat);
        auto sum     = mm2->add_instruction(migraphx::op::add{}, input, concatb);
        auto relu    = mm2->add_instruction(migraphx::op::relu{}, sum);
        mm2->add_instruction(pass_op{}, relu);
    }
    EXPECT(p1.sort() == p2.sort());
}

// The slices are taken over two axes at once (1 and 3), so the branches are
// not merged and the program is expected to be unchanged.
TEST_CASE(simplify_split_add_relu_multi_axes)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 6}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::op::broadcast{1, {3, 1, 4, 3}};
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(migraphx::op::slice{{1, 3}, {0, 0}, {1, 3}}, input);
        auto y     = mm1->add_instruction(migraphx::op::slice{{1, 3}, {1, 3}, {2, 6}}, input);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::op::add{}, x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
        auto sum2  = mm1->add_instruction(migraphx::op::add{}, y, twob);
        auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
        auto add   = mm1->add_instruction(migraphx::op::add{}, relu1, relu2);
        mm1->add_instruction(pass_op{}, add);
    }
    // Snapshot before the pass; the pass must leave p1 identical to it.
    migraphx::program p2 = p1;
    run_pass(p1);
    EXPECT(p1.sort() == p2.sort());
}

// Like simplify_split_add_relu, but slice x has an extra direct user (add2).
// Expected: the add+relu branches are still merged, and the extra user keeps
// its own slice of the raw input.
TEST_CASE(simplify_split_add_relu_used_multiple_split1)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::op::broadcast{1, {3, 1, 4}};
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
        auto y     = mm1->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::op::add{}, x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
        auto sum2  = mm1->add_instruction(migraphx::op::add{}, y, twob);
        auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
        auto add1  = mm1->add_instruction(migraphx::op::add{}, relu1, relu2);
        auto add2  = mm1->add_instruction(migraphx::op::add{}, x, add1);
        mm1->add_instruction(pass_op{}, add2);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2    = p2.get_main_module();
        auto b       = migraphx::op::broadcast{1, {3, 2, 4}};
        auto input   = mm2->add_parameter("input", s);
        auto slice   = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
        auto one     = mm2->add_literal(1);
        auto two     = mm2->add_literal(2);
        auto concat  = mm2->add_instruction(migraphx::op::concat{0}, one, two);
        auto concatb = mm2->add_instruction(b, concat);
        auto sum     = mm2->add_instruction(migraphx::op::add{}, input, concatb);
        auto relu    = mm2->add_instruction(migraphx::op::relu{}, sum);
        auto x       = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, relu);
        auto y       = mm2->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, relu);
        auto add1    = mm2->add_instruction(migraphx::op::add{}, x, y);
        auto add2    = mm2->add_instruction(migraphx::op::add{}, slice, add1);
        mm2->add_instruction(pass_op{}, add2);
    }
    EXPECT(p1.sort() == p2.sort());
}

// Like simplify_split_add_relu_used_multiple_split1, but the extra user of
// slice x is relu(x). Expected: the branches are merged, and relu(x) stays on
// its own slice of the raw input.
TEST_CASE(simplify_split_add_relu_used_multiple_split2)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::op::broadcast{1, {3, 1, 4}};
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
        auto y     = mm1->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
        auto z     = mm1->add_instruction(migraphx::op::relu{}, x);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::op::add{}, x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
        auto sum2  = mm1->add_instruction(migraphx::op::add{}, y, twob);
        auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
        auto add1  = mm1->add_instruction(migraphx::op::add{}, relu1, relu2);
        auto add2  = mm1->add_instruction(migraphx::op::add{}, z, add1);
        mm1->add_instruction(pass_op{}, add2);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2    = p2.get_main_module();
        auto b       = migraphx::op::broadcast{1, {3, 2, 4}};
        auto input   = mm2->add_parameter("input", s);
        auto slice   = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
        auto z       = mm2->add_instruction(migraphx::op::relu{}, slice);
        auto one     = mm2->add_literal(1);
        auto two     = mm2->add_literal(2);
        auto concat  = mm2->add_instruction(migraphx::op::concat{0}, one, two);
        auto concatb = mm2->add_instruction(b, concat);
        auto sum     = mm2->add_instruction(migraphx::op::add{}, input, concatb);
        auto relu    = mm2->add_instruction(migraphx::op::relu{}, sum);
        auto x       = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, relu);
        auto y       = mm2->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, relu);
        auto add1    = mm2->add_instruction(migraphx::op::add{}, x, y);
        auto add2    = mm2->add_instruction(migraphx::op::add{}, z, add1);
        mm2->add_instruction(pass_op{}, add2);
    }
    EXPECT(p1.sort() == p2.sort());
}

// Two slices of one input are combined directly by an add. The program is
// expected to be unchanged by the pass.
TEST_CASE(simplify_split_between_add)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
        auto y     = mm1->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
        auto sum   = mm1->add_instruction(migraphx::op::add{}, x, y);
        mm1->add_instruction(pass_op{}, sum);
    }
    // Snapshot before the pass; the pass must leave p1 identical to it.
    migraphx::program p2 = p1;
    run_pass(p1);
    EXPECT(p1.sort() == p2.sort());
}

// Two dots share the same left operand (input) and use different literal right
// operands. Expected: one dot against concat{2}(a, b), with the two original
// results recovered as slices [0,2) and [2,4) along axis 2.
TEST_CASE(simplify_dot_horiz)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto input = mm1->add_parameter("input", s);
        auto a     = mm1->add_literal(migraphx::generate_literal(s, 0));
        auto b     = mm1->add_literal(migraphx::generate_literal(s, 1));
        auto x     = mm1->add_instruction(migraphx::op::dot{}, input, a);
        auto y     = mm1->add_instruction(migraphx::op::dot{}, input, b);
        auto sum   = mm1->add_instruction(migraphx::op::add{}, x, y);
        mm1->add_instruction(pass_op{}, sum);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2   = p2.get_main_module();
        auto input  = mm2->add_parameter("input", s);
        auto a      = mm2->add_literal(migraphx::generate_literal(s, 0));
        auto b      = mm2->add_literal(migraphx::generate_literal(s, 1));
        auto concat = mm2->add_instruction(migraphx::op::concat{2}, a, b);
        auto dot    = mm2->add_instruction(migraphx::op::dot{}, input, concat);
        auto x      = mm2->add_instruction(migraphx::op::slice{{2}, {0}, {2}}, dot);
        auto y      = mm2->add_instruction(migraphx::op::slice{{2}, {2}, {4}}, dot);
        auto sum    = mm2->add_instruction(migraphx::op::add{}, x, y);
        mm2->add_instruction(pass_op{}, sum);
    }
    EXPECT(p1.sort() == p2.sort());
}

// Like simplify_dot_horiz, but both dots use the same literal right operand.
// Expected: one dot against concat{2}(a, a), followed by two slices along
// axis 2.
TEST_CASE(simplify_dot_horiz_same_constant)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto input = mm1->add_parameter("input", s);
        auto a     = mm1->add_literal(migraphx::generate_literal(s, 0));
        auto x     = mm1->add_instruction(migraphx::op::dot{}, input, a);
        auto y     = mm1->add_instruction(migraphx::op::dot{}, input, a);
        auto sum   = mm1->add_instruction(migraphx::op::add{}, x, y);
        mm1->add_instruction(pass_op{}, sum);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2   = p2.get_main_module();
        auto input  = mm2->add_parameter("input", s);
        auto a      = mm2->add_literal(migraphx::generate_literal(s, 0));
        auto concat = mm2->add_instruction(migraphx::op::concat{2}, a, a);
        auto dot    = mm2->add_instruction(migraphx::op::dot{}, input, concat);
        auto x      = mm2->add_instruction(migraphx::op::slice{{2}, {0}, {2}}, dot);
        auto y      = mm2->add_instruction(migraphx::op::slice{{2}, {2}, {4}}, dot);
        auto sum    = mm2->add_instruction(migraphx::op::add{}, x, y);
        mm2->add_instruction(pass_op{}, sum);
    }
    EXPECT(p1.sort() == p2.sort());
}

// When the shared input is the left operand of one dot but the right operand of
// the other, the dots are not merged: the test expects the pass to leave the
// program unchanged (p2 is a snapshot taken before the pass runs).
TEST_CASE(simplify_dot_horiz_flipped)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto input = mm1->add_parameter("input", s);
        auto a     = mm1->add_literal(migraphx::generate_literal(s, 0));
        auto b     = mm1->add_literal(migraphx::generate_literal(s, 1));
        auto x     = mm1->add_instruction(migraphx::op::dot{}, input, a);
        auto y     = mm1->add_instruction(migraphx::op::dot{}, b, input);
        auto sum   = mm1->add_instruction(migraphx::op::add{}, x, y);
        mm1->add_instruction(pass_op{}, sum);
    }

    migraphx::program p2 = p1;
    run_pass(p1);
    EXPECT(p1.sort() == p2.sort());
}

// Two convolutions over the same input with different literal weights are
// merged into one convolution over the weights concatenated along axis 0; the
// two original outputs are recovered by slicing axis 1 of the result.
TEST_CASE(simplify_conv_horiz)
{
    auto s  = migraphx::shape{migraphx::shape::int32_type, {8, 3, 64, 64}};
    auto ws = migraphx::shape{migraphx::shape::int32_type, {12, 3, 3, 3}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto input = mm1->add_parameter("input", s);
        auto a     = mm1->add_literal(migraphx::generate_literal(ws, 0));
        auto b     = mm1->add_literal(migraphx::generate_literal(ws, 1));
        auto x     = mm1->add_instruction(migraphx::op::convolution{}, input, a);
        auto y     = mm1->add_instruction(migraphx::op::convolution{}, input, b);
        auto sum   = mm1->add_instruction(migraphx::op::add{}, x, y);
        mm1->add_instruction(pass_op{}, sum);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2   = p2.get_main_module();
        auto input  = mm2->add_parameter("input", s);
        auto a      = mm2->add_literal(migraphx::generate_literal(ws, 0));
        auto b      = mm2->add_literal(migraphx::generate_literal(ws, 1));
        auto concat = mm2->add_instruction(migraphx::op::concat{0}, a, b);
        auto conv   = mm2->add_instruction(migraphx::op::convolution{}, input, concat);
        auto x      = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {12}}, conv);
        auto y      = mm2->add_instruction(migraphx::op::slice{{1}, {12}, {24}}, conv);
        auto sum    = mm2->add_instruction(migraphx::op::add{}, x, y);
        mm2->add_instruction(pass_op{}, sum);
    }
    EXPECT(p1.sort() == p2.sort());
}

1360
1361
1362
1363
1364
1365
TEST_CASE(simplify_group_conv_horiz)
{
    auto s  = migraphx::shape{migraphx::shape::int32_type, {1, 32, 111, 111}};
    auto ws = migraphx::shape{migraphx::shape::int32_type, {32, 1, 7, 7}};
    migraphx::program p1;
    {
1366
1367
1368
1369
        auto* mm1 = p1.get_main_module();
        auto x    = mm1->add_parameter("x", s);
        auto w1   = mm1->add_literal(migraphx::generate_literal(ws, 1));
        auto w2   = mm1->add_literal(migraphx::generate_literal(ws, 2));
1370
        auto conv1 =
1371
            mm1->add_instruction(migraphx::op::convolution{{3, 3}, {2, 2}, {1, 1}, 32}, x, w1);
1372
        auto conv2 =
1373
1374
            mm1->add_instruction(migraphx::op::convolution{{3, 3}, {2, 2}, {1, 1}, 32}, x, w2);
        mm1->add_instruction(pass_op{}, conv1, conv2);
1375
1376
1377
1378
1379
1380
1381
1382
    }
    migraphx::program p2 = p1;
    run_pass(p1);

    EXPECT(p1.sort() == p2.sort());
}

// Two independent horizontal-fusion groups hang off the same input. Each group
// is merged separately:
//  - the pair of convolutions is merged with weights concatenated on axis 0 and
//    outputs sliced on axis 1;
//  - the pair of dots is merged with right operands concatenated on axis 3 and
//    outputs sliced on axis 3.
TEST_CASE(simplify_conv_horiz_grouped)
{
    auto s   = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
    auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}};
    auto ws2 = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto input = mm1->add_parameter("input", s);
        auto a     = mm1->add_literal(migraphx::generate_literal(ws1, 0));
        auto b     = mm1->add_literal(migraphx::generate_literal(ws1, 1));
        auto c     = mm1->add_literal(migraphx::generate_literal(ws2, 2));
        auto d     = mm1->add_literal(migraphx::generate_literal(ws2, 3));
        auto convx = mm1->add_instruction(migraphx::op::convolution{{1, 1}}, input, a);
        auto convy = mm1->add_instruction(migraphx::op::convolution{{1, 1}}, input, b);
        auto dotx  = mm1->add_instruction(migraphx::op::dot{}, input, c);
        auto doty  = mm1->add_instruction(migraphx::op::dot{}, input, d);
        auto sum1  = mm1->add_instruction(migraphx::op::add{}, convx, convy);
        auto sum2  = mm1->add_instruction(migraphx::op::add{}, dotx, doty);
        auto sum3  = mm1->add_instruction(migraphx::op::add{}, sum1, sum2);

        mm1->add_instruction(pass_op{}, sum3);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2    = p2.get_main_module();
        auto input   = mm2->add_parameter("input", s);
        auto a       = mm2->add_literal(migraphx::generate_literal(ws1, 0));
        auto b       = mm2->add_literal(migraphx::generate_literal(ws1, 1));
        auto c       = mm2->add_literal(migraphx::generate_literal(ws2, 2));
        auto d       = mm2->add_literal(migraphx::generate_literal(ws2, 3));
        auto concat1 = mm2->add_instruction(migraphx::op::concat{0}, a, b);
        auto concat2 = mm2->add_instruction(migraphx::op::concat{3}, c, d);
        auto conv    = mm2->add_instruction(migraphx::op::convolution{{1, 1}}, input, concat1);
        auto convx   = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {6}}, conv);
        auto convy   = mm2->add_instruction(migraphx::op::slice{{1}, {6}, {12}}, conv);
        auto sum1    = mm2->add_instruction(migraphx::op::add{}, convx, convy);
        auto dot     = mm2->add_instruction(migraphx::op::dot{}, input, concat2);
        auto dotx    = mm2->add_instruction(migraphx::op::slice{{3}, {0}, {64}}, dot);
        auto doty    = mm2->add_instruction(migraphx::op::slice{{3}, {64}, {128}}, dot);
        auto sum2    = mm2->add_instruction(migraphx::op::add{}, dotx, doty);
        auto sum3    = mm2->add_instruction(migraphx::op::add{}, sum1, sum2);
        mm2->add_instruction(pass_op{}, sum3);
    }
    EXPECT(p1.sort() == p2.sort());
}

1431
TEST_CASE(simplify_conv_horiz_grouped_extra1)
1432
1433
1434
1435
1436
1437
{
    auto s   = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
    auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}};
    auto ws2 = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
    migraphx::program p1;
    {
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
        auto* mm1    = p1.get_main_module();
        auto input   = mm1->add_parameter("input", s);
        auto a       = mm1->add_literal(migraphx::generate_literal(ws1, 0));
        auto b       = mm1->add_literal(migraphx::generate_literal(ws1, 1));
        auto c       = mm1->add_literal(migraphx::generate_literal(ws2, 2));
        auto d       = mm1->add_literal(migraphx::generate_literal(ws2, 3));
        auto e       = mm1->add_literal(migraphx::generate_literal(s, 4));
        auto convx   = mm1->add_instruction(migraphx::op::convolution{{1, 1}}, input, a);
        auto convy   = mm1->add_instruction(migraphx::op::convolution{{1, 1}}, input, b);
        auto dotx    = mm1->add_instruction(migraphx::op::dot{}, input, c);
        auto doty    = mm1->add_instruction(migraphx::op::dot{}, input, d);
        auto sqdiffx = mm1->add_instruction(migraphx::op::sqdiff{}, input, e);
        auto sum1    = mm1->add_instruction(migraphx::op::add{}, convx, convy);
        auto sum2    = mm1->add_instruction(migraphx::op::add{}, dotx, doty);
1452
        auto sum3    = sqdiffx;
1453
1454
1455
        auto sum4    = mm1->add_instruction(migraphx::op::add{}, sum1, sum2);
        auto sum5    = mm1->add_instruction(migraphx::op::add{}, sum4, sum3);
        mm1->add_instruction(pass_op{}, sum5);
1456
1457
1458
1459
1460
    }
    run_pass(p1);

    migraphx::program p2;
    {
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
        auto* mm2    = p2.get_main_module();
        auto input   = mm2->add_parameter("input", s);
        auto a       = mm2->add_literal(migraphx::generate_literal(ws1, 0));
        auto b       = mm2->add_literal(migraphx::generate_literal(ws1, 1));
        auto c       = mm2->add_literal(migraphx::generate_literal(ws2, 2));
        auto d       = mm2->add_literal(migraphx::generate_literal(ws2, 3));
        auto e       = mm2->add_literal(migraphx::generate_literal(s, 4));
        auto concat1 = mm2->add_instruction(migraphx::op::concat{0}, a, b);
        auto concat2 = mm2->add_instruction(migraphx::op::concat{3}, c, d);
        auto conv    = mm2->add_instruction(migraphx::op::convolution{{1, 1}}, input, concat1);
        auto convx   = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {6}}, conv);
        auto convy   = mm2->add_instruction(migraphx::op::slice{{1}, {6}, {12}}, conv);
        auto sum1    = mm2->add_instruction(migraphx::op::add{}, convx, convy);
        auto dot     = mm2->add_instruction(migraphx::op::dot{}, input, concat2);
        auto dotx    = mm2->add_instruction(migraphx::op::slice{{3}, {0}, {64}}, dot);
        auto doty    = mm2->add_instruction(migraphx::op::slice{{3}, {64}, {128}}, dot);
        auto sum2    = mm2->add_instruction(migraphx::op::add{}, dotx, doty);
        auto sqdiffx = mm2->add_instruction(migraphx::op::sqdiff{}, input, e);
1479
        auto sum3    = sqdiffx;
1480
1481
1482
        auto sum4    = mm2->add_instruction(migraphx::op::add{}, sum1, sum2);
        auto sum5    = mm2->add_instruction(migraphx::op::add{}, sum4, sum3);
        mm2->add_instruction(pass_op{}, sum5);
1483
1484
1485
1486
    }
    EXPECT(p1.sort() == p2.sort());
}

1487
TEST_CASE(simplify_conv_horiz_grouped_extra2)
1488
1489
1490
1491
1492
1493
{
    auto s   = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
    auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}};
    auto ws2 = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
    migraphx::program p1;
    {
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
        auto* mm1    = p1.get_main_module();
        auto input   = mm1->add_parameter("input", s);
        auto a       = mm1->add_literal(migraphx::generate_literal(ws1, 0));
        auto b       = mm1->add_literal(migraphx::generate_literal(ws1, 1));
        auto c       = mm1->add_literal(migraphx::generate_literal(ws2, 2));
        auto d       = mm1->add_literal(migraphx::generate_literal(ws2, 3));
        auto e       = mm1->add_literal(migraphx::generate_literal(s, 4));
        auto f       = mm1->add_literal(migraphx::generate_literal(s, 5));
        auto convx   = mm1->add_instruction(migraphx::op::convolution{{1, 1}}, input, a);
        auto convy   = mm1->add_instruction(migraphx::op::convolution{{1, 1}}, input, b);
        auto dotx    = mm1->add_instruction(migraphx::op::dot{}, input, c);
        auto doty    = mm1->add_instruction(migraphx::op::dot{}, input, d);
        auto sqdiffx = mm1->add_instruction(migraphx::op::sqdiff{}, input, e);
        auto sqdiffy = mm1->add_instruction(migraphx::op::sqdiff{}, input, f);
        auto sum1    = mm1->add_instruction(migraphx::op::add{}, convx, convy);
        auto sum2    = mm1->add_instruction(migraphx::op::add{}, dotx, doty);
        auto sum3    = mm1->add_instruction(migraphx::op::add{}, sqdiffx, sqdiffy);
        auto sum4    = mm1->add_instruction(migraphx::op::add{}, sum1, sum2);
        auto sum5    = mm1->add_instruction(migraphx::op::add{}, sum4, sum3);
        mm1->add_instruction(pass_op{}, sum5);
1514
1515
1516
1517
1518
    }
    run_pass(p1);

    migraphx::program p2;
    {
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
        auto* mm2    = p2.get_main_module();
        auto input   = mm2->add_parameter("input", s);
        auto a       = mm2->add_literal(migraphx::generate_literal(ws1, 0));
        auto b       = mm2->add_literal(migraphx::generate_literal(ws1, 1));
        auto c       = mm2->add_literal(migraphx::generate_literal(ws2, 2));
        auto d       = mm2->add_literal(migraphx::generate_literal(ws2, 3));
        auto e       = mm2->add_literal(migraphx::generate_literal(s, 4));
        auto f       = mm2->add_literal(migraphx::generate_literal(s, 5));
        auto concat1 = mm2->add_instruction(migraphx::op::concat{0}, a, b);
        auto concat2 = mm2->add_instruction(migraphx::op::concat{3}, c, d);
        auto conv    = mm2->add_instruction(migraphx::op::convolution{{1, 1}}, input, concat1);
        auto convx   = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {6}}, conv);
        auto convy   = mm2->add_instruction(migraphx::op::slice{{1}, {6}, {12}}, conv);
        auto sum1    = mm2->add_instruction(migraphx::op::add{}, convx, convy);
        auto dot     = mm2->add_instruction(migraphx::op::dot{}, input, concat2);
        auto dotx    = mm2->add_instruction(migraphx::op::slice{{3}, {0}, {64}}, dot);
        auto doty    = mm2->add_instruction(migraphx::op::slice{{3}, {64}, {128}}, dot);
        auto sum2    = mm2->add_instruction(migraphx::op::add{}, dotx, doty);
        auto sqdiffx = mm2->add_instruction(migraphx::op::sqdiff{}, input, e);
        auto sqdiffy = mm2->add_instruction(migraphx::op::sqdiff{}, input, f);
        auto sum3    = mm2->add_instruction(migraphx::op::add{}, sqdiffx, sqdiffy);
        auto sum4    = mm2->add_instruction(migraphx::op::add{}, sum1, sum2);
        auto sum5    = mm2->add_instruction(migraphx::op::add{}, sum4, sum3);
        mm2->add_instruction(pass_op{}, sum5);
1543
1544
1545
1546
    }
    EXPECT(p1.sort() == p2.sort());
}

1547
1548
1549
1550
TEST_CASE(simplify_mul_slice_conv_horiz_fusion)
{
    migraphx::program p1;
    {
1551
1552
1553
        auto* mm1 = p1.get_main_module();
        auto x    = mm1->add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
        auto w    = mm1->add_literal(
1554
            migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
1555
1556
        auto conv   = mm1->add_instruction(migraphx::op::convolution{}, x, w);
        auto slice1 = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {384}}, conv);
1557
        auto a1 =
1558
1559
1560
            mm1->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1));
        auto b1  = mm1->add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a1);
        auto mul = mm1->add_instruction(migraphx::op::mul{}, slice1, b1);
1561
        auto a2 =
1562
1563
1564
            mm1->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 2));
        auto b2   = mm1->add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a2);
        auto add1 = mm1->add_instruction(migraphx::op::add{}, mul, b2);
1565
        auto a3 =
1566
1567
1568
1569
1570
            mm1->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3));
        auto b3     = mm1->add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a3);
        auto slice2 = mm1->add_instruction(migraphx::op::slice{{1}, {384}, {768}}, conv);
        auto add2   = mm1->add_instruction(migraphx::op::add{}, slice2, b3);
        mm1->add_instruction(pass_op{}, add1, add2);
1571
1572
1573
1574
1575
    }
    run_pass(p1);

    migraphx::program p2;
    {
1576
1577
1578
        auto* mm2 = p2.get_main_module();
        auto x    = mm2->add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
        auto w    = mm2->add_literal(
1579
            migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
1580
        auto wslice1 = mm2->add_instruction(migraphx::op::slice{{0}, {0}, {384}}, w);
1581
        auto a1 =
1582
1583
1584
1585
1586
1587
            mm2->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1));
        auto b1      = mm2->add_instruction(migraphx::op::broadcast{0, {384, 1024, 1, 1}}, a1);
        auto mul     = mm2->add_instruction(migraphx::op::mul{}, b1, wslice1);
        auto wslice2 = mm2->add_instruction(migraphx::op::slice{{0}, {384}, {768}}, w);
        auto concat1 = mm2->add_instruction(migraphx::op::concat{0}, mul, wslice2);
        auto conv    = mm2->add_instruction(migraphx::op::convolution{}, x, concat1);
1588
        auto a2 =
1589
            mm2->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 2));
1590
        auto a3 =
1591
1592
1593
1594
1595
1596
1597
            mm2->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3));
        auto concat2 = mm2->add_instruction(migraphx::op::concat{}, a2, a3);
        auto b4      = mm2->add_instruction(migraphx::op::broadcast{1, {1, 768, 17, 17}}, concat2);
        auto add     = mm2->add_instruction(migraphx::op::add{}, conv, b4);
        auto slice1  = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {384}}, add);
        auto slice2  = mm2->add_instruction(migraphx::op::slice{{1}, {384}, {768}}, add);
        mm2->add_instruction(pass_op{}, slice1, slice2);
1598
1599
1600
    }
    EXPECT(p1.sort() == p2.sort());
}
1601
1602
1603
1604
1605
1606
TEST_CASE(reorder_reshape_slice)
{
    std::vector<int64_t> perm0 = {0, 2, 1, 3};
    std::vector<int64_t> perm1 = {0, 2, 3, 1};
    auto create_p1             = [&](std::size_t batch_size) {
        migraphx::program p1;
1607
        auto* mm1  = p1.get_main_module();
1608
        auto s     = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
1609
1610
1611
1612
        auto input = mm1->add_parameter("input", s);
        auto slc0  = mm1->add_instruction(migraphx::op::slice{{2}, {0}, {640}}, input);
        auto slc1  = mm1->add_instruction(migraphx::op::slice{{2}, {640}, {1280}}, input);
        auto slc2  = mm1->add_instruction(migraphx::op::slice{{2}, {1280}, {1920}}, input);
1613

1614
1615
1616
        auto c0 = mm1->add_instruction(migraphx::op::contiguous{}, slc0);
        auto c1 = mm1->add_instruction(migraphx::op::contiguous{}, slc1);
        auto c2 = mm1->add_instruction(migraphx::op::contiguous{}, slc2);
1617
1618

        std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 128, 10, 64};
1619
1620
1621
        auto r0                   = mm1->add_instruction(migraphx::op::reshape{lens}, c0);
        auto r1                   = mm1->add_instruction(migraphx::op::reshape{lens}, c1);
        auto r2                   = mm1->add_instruction(migraphx::op::reshape{lens}, c2);
1622

1623
1624
1625
        auto t0 = mm1->add_instruction(migraphx::op::transpose{perm0}, r0);
        auto t1 = mm1->add_instruction(migraphx::op::transpose{perm0}, r1);
        auto t2 = mm1->add_instruction(migraphx::op::transpose{perm1}, r2);
1626

1627
1628
1629
        auto sum = mm1->add_instruction(migraphx::op::add{}, t0, t1);
        auto ret = mm1->add_instruction(migraphx::op::dot{}, sum, t2);
        mm1->add_return({ret});
1630
1631
1632
1633
1634
1635

        return p1;
    };

    auto create_p2 = [&](std::size_t batch_size) {
        migraphx::program p2;
1636
        auto* mm2  = p2.get_main_module();
1637
        auto s     = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
1638
        auto input = mm2->add_parameter("input", s);
1639
        std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 128, 30, 64};
1640
        auto r                    = mm2->add_instruction(migraphx::op::reshape{lens}, input);
1641

1642
1643
1644
        auto slc0 = mm2->add_instruction(migraphx::op::slice{{2}, {0}, {10}}, r);
        auto slc1 = mm2->add_instruction(migraphx::op::slice{{2}, {10}, {20}}, r);
        auto slc2 = mm2->add_instruction(migraphx::op::slice{{2}, {20}, {30}}, r);
1645

1646
1647
1648
        auto t0 = mm2->add_instruction(migraphx::op::transpose{perm0}, slc0);
        auto t1 = mm2->add_instruction(migraphx::op::transpose{perm0}, slc1);
        auto t2 = mm2->add_instruction(migraphx::op::transpose{perm1}, slc2);
1649

1650
1651
1652
        auto sum = mm2->add_instruction(migraphx::op::add{}, t0, t1);
        auto ret = mm2->add_instruction(migraphx::op::dot{}, sum, t2);
        mm2->add_return({ret});
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668

        return p2;
    };

    auto test = [&](std::size_t batch_size) {
        auto p1 = create_p1(batch_size);
        run_pass(p1);
        auto p2 = create_p2(batch_size);
        EXPECT(p1.sort() == p2.sort());
    };

    test(1);
    test(4);
    test(8);
}

1669
TEST_CASE(reorder_reshape_slice_move_axis1)
1670
1671
1672
{
    auto create_p1 = [](std::size_t batch_size) {
        migraphx::program p1;
1673
1674
        auto* mm1 = p1.get_main_module();
        auto s    = migraphx::shape{migraphx::shape::float_type, {batch_size, 256, 96}};
1675
1676
        std::vector<int64_t> perm0 = {0, 2, 1, 3};
        std::vector<int64_t> perm1 = {0, 2, 3, 1};
1677
1678
1679
1680
        auto input                 = mm1->add_parameter("input", s);
        auto slc0 = mm1->add_instruction(migraphx::op::slice{{2}, {0}, {32}}, input);
        auto slc1 = mm1->add_instruction(migraphx::op::slice{{2}, {32}, {64}}, input);
        auto slc2 = mm1->add_instruction(migraphx::op::slice{{2}, {64}, {96}}, input);
1681

1682
1683
1684
        auto c0 = mm1->add_instruction(migraphx::op::contiguous{}, slc0);
        auto c1 = mm1->add_instruction(migraphx::op::contiguous{}, slc1);
        auto c2 = mm1->add_instruction(migraphx::op::contiguous{}, slc2);
1685

1686
        std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 64, 4, 32};
1687
1688
1689
        auto r0                   = mm1->add_instruction(migraphx::op::reshape{lens}, c0);
        auto r1                   = mm1->add_instruction(migraphx::op::reshape{lens}, c1);
        auto r2                   = mm1->add_instruction(migraphx::op::reshape{lens}, c2);
1690

1691
1692
1693
        auto t0 = mm1->add_instruction(migraphx::op::transpose{perm0}, r0);
        auto t1 = mm1->add_instruction(migraphx::op::transpose{perm0}, r1);
        auto t2 = mm1->add_instruction(migraphx::op::transpose{perm1}, r2);
1694

1695
1696
1697
        auto sum = mm1->add_instruction(migraphx::op::add{}, t0, t1);
        auto ret = mm1->add_instruction(migraphx::op::dot{}, sum, t2);
        mm1->add_return({ret});
1698
1699
1700
1701

        return p1;
    };

1702
1703
    auto create_p2 = [](std::size_t batch_size) {
        migraphx::program p;
1704
1705
        auto* mm = p.get_main_module();
        auto s   = migraphx::shape{migraphx::shape::float_type, {batch_size, 256, 96}};
1706
1707
        std::vector<int64_t> perm0 = {0, 2, 1, 3};
        std::vector<int64_t> perm1 = {0, 2, 3, 1};
1708
        auto input                 = mm->add_parameter("input", s);
1709
        std::vector<int64_t> lens  = {static_cast<int64_t>(batch_size), 64, 4, 96};
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
        auto rsp                   = mm->add_instruction(migraphx::op::reshape{lens}, input);
        auto slc0                  = mm->add_instruction(migraphx::op::slice{{3}, {0}, {32}}, rsp);
        auto t0                    = mm->add_instruction(migraphx::op::transpose{perm0}, slc0);
        auto slc1                  = mm->add_instruction(migraphx::op::slice{{3}, {32}, {64}}, rsp);
        auto t1                    = mm->add_instruction(migraphx::op::transpose{perm0}, slc1);
        auto slc2                  = mm->add_instruction(migraphx::op::slice{{3}, {64}, {96}}, rsp);
        auto t2                    = mm->add_instruction(migraphx::op::transpose{perm1}, slc2);

        auto sum = mm->add_instruction(migraphx::op::add{}, t0, t1);
        auto ret = mm->add_instruction(migraphx::op::dot{}, sum, t2);
        mm->add_return({ret});
1721
1722
1723
1724

        return p;
    };

1725
1726
    auto test = [&](std::size_t batch_size) {
        auto p1 = create_p1(batch_size);
1727
        auto p2 = create_p2(batch_size);
1728
1729
1730
1731
1732
1733
1734
1735
        run_pass(p1);
        EXPECT(p1.sort() == p2.sort());
    };

    test(4);
    test(8);
}

1736
1737
1738
1739
TEST_CASE(reorder_reshape_slice_move_axis2)
{
    auto create_p1 = [] {
        migraphx::program p1;
1740
        auto* mm1 = p1.get_main_module();
1741
        migraphx::shape s{migraphx::shape::float_type, {128, 96}};
1742
1743
1744
1745
        auto input = mm1->add_parameter("input", s);
        auto slc0  = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {32}}, input);
        auto slc1  = mm1->add_instruction(migraphx::op::slice{{1}, {32}, {64}}, input);
        auto slc2  = mm1->add_instruction(migraphx::op::slice{{1}, {64}, {96}}, input);
1746

1747
1748
1749
        auto c0 = mm1->add_instruction(migraphx::op::contiguous{}, slc0);
        auto c1 = mm1->add_instruction(migraphx::op::contiguous{}, slc1);
        auto c2 = mm1->add_instruction(migraphx::op::contiguous{}, slc2);
1750
1751

        std::vector<int64_t> lens = {1, 16, 8, 32};
1752
1753
1754
        auto r0                   = mm1->add_instruction(migraphx::op::reshape{lens}, c0);
        auto r1                   = mm1->add_instruction(migraphx::op::reshape{lens}, c1);
        auto r2                   = mm1->add_instruction(migraphx::op::reshape{lens}, c2);
1755

1756
1757
1758
        auto sum = mm1->add_instruction(migraphx::op::add{}, r0, r1);
        auto ret = mm1->add_instruction(migraphx::op::mul{}, sum, r2);
        mm1->add_return({ret});
1759
1760
1761
1762
1763
1764

        return p1;
    };

    auto create_p2 = [] {
        migraphx::program p;
1765
        auto* mm                  = p.get_main_module();
1766
        auto s                    = migraphx::shape{migraphx::shape::float_type, {128, 96}};
1767
        auto input                = mm->add_parameter("input", s);
1768
        std::vector<int64_t> lens = {1, 16, 8, 96};
1769
1770
1771
1772
        auto rsp                  = mm->add_instruction(migraphx::op::reshape{lens}, input);
        auto slc0                 = mm->add_instruction(migraphx::op::slice{{3}, {0}, {32}}, rsp);
        auto slc1                 = mm->add_instruction(migraphx::op::slice{{3}, {32}, {64}}, rsp);
        auto slc2                 = mm->add_instruction(migraphx::op::slice{{3}, {64}, {96}}, rsp);
1773

1774
1775
1776
        auto sum = mm->add_instruction(migraphx::op::add{}, slc0, slc1);
        auto ret = mm->add_instruction(migraphx::op::mul{}, sum, slc2);
        mm->add_return({ret});
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790

        return p;
    };

    auto p1 = create_p1();
    auto p2 = create_p2();
    run_pass(p1);
    EXPECT(p1.sort() == p2.sort());
}

// Each 32-wide slice is reshaped to {1, 16, 16, 16}, which does not keep a
// trailing 32-wide dimension. The test expects the pass to leave the program
// unchanged (p2 is a snapshot taken before the pass runs).
TEST_CASE(reorder_reshape_slice_not_apply)
{
    auto create_p = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::float_type, {128, 96}};
        auto input = mm->add_parameter("input", s);
        auto slc0  = mm->add_instruction(migraphx::op::slice{{1}, {0}, {32}}, input);
        auto slc1  = mm->add_instruction(migraphx::op::slice{{1}, {32}, {64}}, input);
        auto slc2  = mm->add_instruction(migraphx::op::slice{{1}, {64}, {96}}, input);

        auto c0 = mm->add_instruction(migraphx::op::contiguous{}, slc0);
        auto c1 = mm->add_instruction(migraphx::op::contiguous{}, slc1);
        auto c2 = mm->add_instruction(migraphx::op::contiguous{}, slc2);

        std::vector<int64_t> lens = {1, 16, 16, 16};
        auto r0                   = mm->add_instruction(migraphx::op::reshape{lens}, c0);
        auto r1                   = mm->add_instruction(migraphx::op::reshape{lens}, c1);
        auto r2                   = mm->add_instruction(migraphx::op::reshape{lens}, c2);

        auto sum = mm->add_instruction(migraphx::op::add{}, r0, r1);
        auto ret = mm->add_instruction(migraphx::op::mul{}, sum, r2);
        mm->add_return({ret});

        return p;
    };

    auto p1 = create_p();
    auto p2 = p1;
    run_pass(p1);
    EXPECT(p1.sort() == p2.sort());
}

1820
1821
1822
1823
TEST_CASE(reorder_reshape_slice_diff_dims)
{
    auto create_p1 = [](std::size_t batch_size) {
        migraphx::program p1;
1824
1825
        auto* mm1 = p1.get_main_module();
        auto s    = migraphx::shape{migraphx::shape::float_type, {batch_size, 96, 96}};
1826
1827
        std::vector<int64_t> perm0 = {0, 2, 1, 3};
        std::vector<int64_t> perm1 = {0, 2, 3, 1};
1828
1829
1830
1831
        auto input                 = mm1->add_parameter("input", s);
        auto slc0 = mm1->add_instruction(migraphx::op::slice{{2}, {0}, {32}}, input);
        auto slc1 = mm1->add_instruction(migraphx::op::slice{{2}, {32}, {64}}, input);
        auto slc2 = mm1->add_instruction(migraphx::op::slice{{2}, {64}, {96}}, input);
1832

1833
1834
1835
        auto c0 = mm1->add_instruction(migraphx::op::contiguous{}, slc0);
        auto c1 = mm1->add_instruction(migraphx::op::contiguous{}, slc1);
        auto c2 = mm1->add_instruction(migraphx::op::contiguous{}, slc2);
1836
1837
1838

        std::vector<int64_t> lens  = {static_cast<int64_t>(batch_size), 32, 3, 32};
        std::vector<int64_t> lens1 = {static_cast<int64_t>(batch_size), 48, 2, 32};
1839
1840
1841
        auto r0                    = mm1->add_instruction(migraphx::op::reshape{lens}, c0);
        auto r1                    = mm1->add_instruction(migraphx::op::reshape{lens}, c1);
        auto r2                    = mm1->add_instruction(migraphx::op::reshape{lens1}, c2);
1842

1843
        mm1->add_return({r0, r1, r2});
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863

        return p1;
    };

    auto test = [&](std::size_t batch_size) {
        auto p1 = create_p1(batch_size);
        auto p2 = p1;
        run_pass(p1);
        EXPECT(p1.sort() == p2.sort());
    };

    test(4);
    test(8);
}

// Three slices of the last axis, each followed by the same transpose, are
// rewritten as one transpose of the whole input followed by slices on the
// transposed axis (1). The test runs for several batch sizes.
TEST_CASE(reorder_slice_trans)
{
    std::vector<int64_t> perm = {0, 2, 1};
    auto create_p1            = [&](std::size_t batch_size) {
        migraphx::program p1;
        auto* mm1  = p1.get_main_module();
        auto s     = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
        auto input = mm1->add_parameter("input", s);
        auto slc0  = mm1->add_instruction(migraphx::op::slice{{2}, {0}, {640}}, input);
        auto slc1  = mm1->add_instruction(migraphx::op::slice{{2}, {640}, {1280}}, input);
        auto slc2  = mm1->add_instruction(migraphx::op::slice{{2}, {1280}, {1920}}, input);

        auto t0 = mm1->add_instruction(migraphx::op::transpose{perm}, slc0);
        auto t1 = mm1->add_instruction(migraphx::op::transpose{perm}, slc1);
        auto t2 = mm1->add_instruction(migraphx::op::transpose{perm}, slc2);

        auto sum = mm1->add_instruction(migraphx::op::add{}, t0, t1);
        auto ret = mm1->add_instruction(migraphx::op::mul{}, sum, t2);
        mm1->add_return({ret});

        return p1;
    };

    auto create_p2 = [&](std::size_t batch_size) {
        migraphx::program p2;
        auto* mm2  = p2.get_main_module();
        auto s     = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
        auto input = mm2->add_parameter("input", s);
        auto r     = mm2->add_instruction(migraphx::op::transpose{perm}, input);

        auto slc0 = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {640}}, r);
        auto slc1 = mm2->add_instruction(migraphx::op::slice{{1}, {640}, {1280}}, r);
        auto slc2 = mm2->add_instruction(migraphx::op::slice{{1}, {1280}, {1920}}, r);

        auto sum = mm2->add_instruction(migraphx::op::add{}, slc0, slc1);
        auto ret = mm2->add_instruction(migraphx::op::mul{}, sum, slc2);
        mm2->add_return({ret});

        return p2;
    };

    auto test = [&](std::size_t batch_size) {
        auto p1 = create_p1(batch_size);
        run_pass(p1);
        auto p2 = create_p2(batch_size);
        EXPECT(p1.sort() == p2.sort());
    };

    test(1);
    test(8);
}

// The transposes applied to the slices use different permutations (perm0 for
// the first two, perm1 for the third), so the slice/transpose reordering must
// not apply. The test expects the program to be unchanged by the pass.
TEST_CASE(reorder_slice_trans_diff_perm)
{
    auto create_p1 = [](std::size_t batch_size) {
        migraphx::program p1;
        auto* mm1 = p1.get_main_module();
        auto s    = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
        std::vector<int64_t> perm0 = {0, 2, 1};
        std::vector<int64_t> perm1 = {0, 1, 2};
        auto input                 = mm1->add_parameter("input", s);
        auto slc0 = mm1->add_instruction(migraphx::op::slice{{2}, {0}, {640}}, input);
        auto slc1 = mm1->add_instruction(migraphx::op::slice{{2}, {640}, {1280}}, input);
        auto slc2 = mm1->add_instruction(migraphx::op::slice{{2}, {1280}, {1920}}, input);

        auto t0 = mm1->add_instruction(migraphx::op::transpose{perm0}, slc0);
        auto t1 = mm1->add_instruction(migraphx::op::transpose{perm0}, slc1);
        auto t2 = mm1->add_instruction(migraphx::op::transpose{perm1}, slc2);

        auto sum = mm1->add_instruction(migraphx::op::add{}, t0, t1);
        auto ret = mm1->add_instruction(migraphx::op::dot{}, sum, t2);
        mm1->add_return({ret});

        return p1;
    };

    auto test = [&](std::size_t batch_size) {
        auto p1 = create_p1(batch_size);
        // Snapshot BEFORE the pass runs. Copying after run_pass would compare the
        // optimized program with itself, so the test could never fail.
        auto p2 = p1;
        run_pass(p1);
        EXPECT(p1.sort() == p2.sort());
    };

    test(1);
    test(4);
}

Paul's avatar
Paul committed
1946
int main(int argc, const char* argv[]) { test::run(argc, argv); }