simplify_algebra_test.cpp 96.7 KB
Newer Older
Paul's avatar
Paul committed
1
2
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp>
3
#include <migraphx/pass_manager.hpp>
Paul's avatar
Paul committed
4
#include <migraphx/operators.hpp>
Paul's avatar
Paul committed
5
6
7
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
Paul's avatar
Paul committed
8
#include <basic_ops.hpp>
9
10
#include <migraphx/make_op.hpp>

Paul's avatar
Paul committed
11
12
#include <test.hpp>

13
// Runs simplify_algebra followed by dead_code_elimination on the main module of p,
// so tests can compare the simplified program against a hand-built expected one.
void run_pass(migraphx::program& p)
{
    migraphx::run_passes(*p.get_main_module(),
                         {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}});
}
Paul's avatar
Paul committed
18

Paul's avatar
Paul committed
19
// (x + 1) + (y + 2) is expected to be regrouped as (x + y) + (1 + 2),
// gathering the literal operands into a single sub-expression.
TEST_CASE(simplify_add1)
{
    migraphx::program p1;
    {
        auto* mm1 = p1.get_main_module();
        auto x    = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto y    = mm1->add_parameter("y", {migraphx::shape::int32_type, {1}});
        auto one  = mm1->add_literal(1);
        auto two  = mm1->add_literal(2);
        auto sum1 = mm1->add_instruction(migraphx::make_op("add"), x, one);
        auto sum2 = mm1->add_instruction(migraphx::make_op("add"), y, two);
        auto sum3 = mm1->add_instruction(migraphx::make_op("add"), sum1, sum2);
        mm1->add_instruction(pass_op{}, sum3);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2 = p2.get_main_module();
        auto x    = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto y    = mm2->add_parameter("y", {migraphx::shape::int32_type, {1}});
        auto one  = mm2->add_literal(1);
        auto two  = mm2->add_literal(2);
        auto sum1 = mm2->add_instruction(migraphx::make_op("add"), one, two);
        auto sum2 = mm2->add_instruction(migraphx::make_op("add"), x, y);
        auto sum3 = mm2->add_instruction(migraphx::make_op("add"), sum2, sum1);
        mm2->add_instruction(pass_op{}, sum3);
    }
    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
50
// Same as simplify_add1 but with the literal on the left of each inner add:
// (1 + x) + (2 + y) is expected to become (x + y) + (1 + 2).
TEST_CASE(simplify_add2)
{
    migraphx::program p1;
    {
        auto* mm1 = p1.get_main_module();
        auto x    = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto y    = mm1->add_parameter("y", {migraphx::shape::int32_type, {1}});
        auto one  = mm1->add_literal(1);
        auto two  = mm1->add_literal(2);
        auto sum1 = mm1->add_instruction(migraphx::make_op("add"), one, x);
        auto sum2 = mm1->add_instruction(migraphx::make_op("add"), two, y);
        auto sum3 = mm1->add_instruction(migraphx::make_op("add"), sum1, sum2);
        mm1->add_instruction(pass_op{}, sum3);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2 = p2.get_main_module();
        auto x    = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto y    = mm2->add_parameter("y", {migraphx::shape::int32_type, {1}});
        auto one  = mm2->add_literal(1);
        auto two  = mm2->add_literal(2);
        auto sum1 = mm2->add_instruction(migraphx::make_op("add"), one, two);
        auto sum2 = mm2->add_instruction(migraphx::make_op("add"), x, y);
        auto sum3 = mm2->add_instruction(migraphx::make_op("add"), sum2, sum1);
        mm2->add_instruction(pass_op{}, sum3);
    }
    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
81
// (1 + x) + (1 + 2) has a single non-literal operand; the expected result is
// x + (1 + (1 + 2)), with all literals nested on one side.
TEST_CASE(simplify_add3)
{
    migraphx::program p1;
    {
        auto* mm1 = p1.get_main_module();
        auto x    = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto one  = mm1->add_literal(1);
        auto two  = mm1->add_literal(2);
        auto sum1 = mm1->add_instruction(migraphx::make_op("add"), one, x);
        auto sum2 = mm1->add_instruction(migraphx::make_op("add"), one, two);
        auto sum3 = mm1->add_instruction(migraphx::make_op("add"), sum1, sum2);
        mm1->add_instruction(pass_op{}, sum3);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2 = p2.get_main_module();
        auto x    = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto one  = mm2->add_literal(1);
        auto two  = mm2->add_literal(2);
        auto sum1 = mm2->add_instruction(migraphx::make_op("add"), one, two);
        auto sum2 = mm2->add_instruction(migraphx::make_op("add"), one, sum1);
        auto sum3 = mm2->add_instruction(migraphx::make_op("add"), x, sum2);
        mm2->add_instruction(pass_op{}, sum3);
    }
    EXPECT(p1 == p2);
}

110
111
112
113
114
115
116
// (x + broadcast(1)) + (y + broadcast(2)) is expected to become
// (x + y) + broadcast(1 + 2): the literal add happens on the small inner
// shape and is broadcast once afterwards.
TEST_CASE(simplify_add_broadcast1)
{
    migraphx::shape inner{migraphx::shape::int32_type, {2}};
    migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
    migraphx::op::broadcast b{1, {1, 2, 3, 3}};
    migraphx::program p1;
    {
        auto* mm1 = p1.get_main_module();
        auto x    = mm1->add_parameter("x", outer);
        auto y    = mm1->add_parameter("y", outer);
        auto one  = mm1->add_literal({inner, {1, 1}});
        auto oneb = mm1->add_instruction(b, one);
        auto two  = mm1->add_literal({inner, {2, 2}});
        auto twob = mm1->add_instruction(b, two);
        auto sum1 = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
        auto sum2 = mm1->add_instruction(migraphx::make_op("add"), y, twob);
        auto sum3 = mm1->add_instruction(migraphx::make_op("add"), sum1, sum2);
        mm1->add_instruction(pass_op{}, sum3);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2  = p2.get_main_module();
        auto x     = mm2->add_parameter("x", outer);
        auto y     = mm2->add_parameter("y", outer);
        auto one   = mm2->add_literal({inner, {1, 1}});
        auto two   = mm2->add_literal({inner, {2, 2}});
        auto sum1  = mm2->add_instruction(migraphx::make_op("add"), one, two);
        auto sum1b = mm2->add_instruction(b, sum1);
        auto sum2  = mm2->add_instruction(migraphx::make_op("add"), x, y);
        auto sum3  = mm2->add_instruction(migraphx::make_op("add"), sum2, sum1b);
        mm2->add_instruction(pass_op{}, sum3);
    }
    EXPECT(p1 == p2);
}

// Here `two` is a full outer-shape literal rather than a broadcast of an inner
// literal; the pass is expected to leave the graph unchanged (p2 is a fresh,
// unoptimized copy of the same program).
TEST_CASE(simplify_add_broadcast2)
{
    migraphx::shape inner{migraphx::shape::int32_type, {2}};
    migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
    migraphx::op::broadcast b{1, {1, 2, 3, 3}};
    auto create_program = [&] {
        migraphx::program p;
        auto* mm  = p.get_main_module();
        auto x    = mm->add_parameter("x", outer);
        auto y    = mm->add_parameter("y", outer);
        auto one  = mm->add_literal({inner, {1, 1}});
        auto oneb = mm->add_instruction(b, one);
        auto two = mm->add_literal({outer, {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}});
        auto sum1 = mm->add_instruction(migraphx::make_op("add"), x, y);
        auto sum2 = mm->add_instruction(migraphx::make_op("add"), oneb, two);
        auto sum3 = mm->add_instruction(migraphx::make_op("add"), sum2, sum1);
        mm->add_instruction(pass_op{}, sum3);
        return p;
    };
    migraphx::program p1 = create_program();
    run_pass(p1);

    migraphx::program p2 = create_program();
    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
173
// TODO: Add test case
174
// TEST_CASE(simplify_add4)
Paul's avatar
Paul committed
175
176
void simplify_add4()
{
Paul's avatar
Paul committed
177
    migraphx::program p1;
Paul's avatar
Paul committed
178
    {
179
180
181
182
183
        auto* mm1 = p1.get_main_module();
        auto x    = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto y    = mm1->add_parameter("y", {migraphx::shape::int32_type, {1}});
        auto one  = mm1->add_literal(1);
        auto two  = mm1->add_literal(2);
184
185
186
        auto sum1 = mm1->add_instruction(migraphx::make_op("add"), one, x);
        auto sum2 = mm1->add_instruction(migraphx::make_op("add"), sum1, y);
        auto sum3 = mm1->add_instruction(migraphx::make_op("add"), sum2, two);
187
        mm1->add_instruction(pass_op{}, sum3);
Paul's avatar
Paul committed
188
    }
189
    run_pass(p1);
Paul's avatar
Paul committed
190

Paul's avatar
Paul committed
191
    migraphx::program p2;
Paul's avatar
Paul committed
192
    {
193
194
195
196
197
        auto* mm2 = p2.get_main_module();
        auto x    = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto y    = mm2->add_parameter("y", {migraphx::shape::int32_type, {1}});
        auto one  = mm2->add_literal(1);
        auto two  = mm2->add_literal(2);
198
199
200
        auto sum1 = mm2->add_instruction(migraphx::make_op("add"), one, two);
        auto sum2 = mm2->add_instruction(migraphx::make_op("add"), x, y);
        auto sum3 = mm2->add_instruction(migraphx::make_op("add"), sum2, sum1);
201
        mm2->add_instruction(pass_op{}, sum3);
Paul's avatar
Paul committed
202
203
204
205
    }
    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
206
207
208
// conv(x, w) * broadcast(a): before the pass the convolution feeds a "mul";
// afterwards the convolution's first consumer is expected to no longer be "mul".
TEST_CASE(simplify_mul_conv1)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto x   = mm->add_parameter("x", {migraphx::shape::int32_type, {1, 128, 28, 28}});
    auto w   = mm->add_literal(
        migraphx::generate_literal({migraphx::shape::int32_type, {256, 128, 3, 3}}));
    auto conv = mm->add_instruction(
        migraphx::make_op("convolution",
                          {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
        x,
        w);
    auto a = mm->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256}}));
    auto b = mm->add_instruction(
        migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 256, 14, 14}}}), a);
    auto mul = mm->add_instruction(migraphx::make_op("mul"), conv, b);
    mm->add_instruction(pass_op{}, mul);
    EXPECT(conv->outputs().front()->name() == "mul");
    run_pass(p);
    auto new_conv =
        std::find_if(p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; });
    // Guard the dereference below: if the pass removed every convolution,
    // find_if returns end() and dereferencing it would be undefined behavior.
    EXPECT(new_conv != p.end());
    EXPECT(new_conv->outputs().front()->name() != "mul");
}

230
231
232
233
// A mul applied to the first channel slice [0, 384) of a conv output is expected
// to be folded into the weights: w is sliced along axis 0, the first half is
// scaled by the broadcast constant, the halves are concatenated back, and the
// convolution runs on the rebuilt weights.
TEST_CASE(simplify_mul_slice_conv1)
{
    migraphx::program p1;
    {
        auto* mm1 = p1.get_main_module();
        auto x    = mm1->add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
        auto w    = mm1->add_literal(
            migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
        auto conv   = mm1->add_instruction(migraphx::make_op("convolution"), x, w);
        auto slice1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv);
        auto a = mm1->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
        auto b = mm1->add_instruction(
            migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a);
        auto mul    = mm1->add_instruction(migraphx::make_op("mul"), slice1, b);
        auto slice2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), conv);
        auto add = mm1->add_instruction(migraphx::make_op("add"), mul, slice2);
        mm1->add_instruction(pass_op{}, add);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2 = p2.get_main_module();
        auto x    = mm2->add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
        auto w    = mm2->add_literal(
            migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
        auto wslice1 = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {384}}}), w);
        auto a = mm2->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
        auto b = mm2->add_instruction(
            migraphx::make_op("broadcast", {{"axis", 0}, {"dims", {384, 1024, 1, 1}}}), a);
        auto mul     = mm2->add_instruction(migraphx::make_op("mul"), b, wslice1);
        auto wslice2 = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {384}}, {"ends", {768}}}), w);
        auto concat =
            mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), mul, wslice2);
        auto conv   = mm2->add_instruction(migraphx::make_op("convolution"), x, concat);
        auto slice1 = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv);
        auto slice2 = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), conv);
        auto add = mm2->add_instruction(migraphx::make_op("add"), slice1, slice2);
        mm2->add_instruction(pass_op{}, add);
    }
    EXPECT(p1 == p2);
}

// The two channel slices [0, 384) and [383, 767) overlap, so the mul cannot be
// pushed into the weights; the program is expected to be unchanged.
TEST_CASE(simplify_mul_slice_conv_overlapping_slice)
{
    migraphx::program p1;
    {
        auto* mm1 = p1.get_main_module();
        auto x    = mm1->add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
        auto w    = mm1->add_literal(
            migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
        auto conv   = mm1->add_instruction(migraphx::make_op("convolution"), x, w);
        auto slice1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv);
        auto a = mm1->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
        auto b = mm1->add_instruction(
            migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a);
        auto mul    = mm1->add_instruction(migraphx::make_op("mul"), slice1, b);
        auto slice2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {383}}, {"ends", {767}}}), conv);
        auto add = mm1->add_instruction(migraphx::make_op("add"), mul, slice2);
        mm1->add_instruction(pass_op{}, add);
    }
    migraphx::program p2 = p1;
    run_pass(p1);
    EXPECT(p1 == p2);
}

// The convolution output is also consumed directly (by an add), not only
// through slices, so the program is expected to be unchanged.
TEST_CASE(simplify_mul_slice_conv_not_all_slice)
{
    migraphx::program p1;
    {
        auto* mm1 = p1.get_main_module();
        auto x    = mm1->add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
        auto w    = mm1->add_literal(
            migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
        auto conv   = mm1->add_instruction(migraphx::make_op("convolution"), x, w);
        auto slice1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv);
        auto a = mm1->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
        auto b = mm1->add_instruction(
            migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a);
        auto mul = mm1->add_instruction(migraphx::make_op("mul"), slice1, b);
        auto c   = mm1->add_literal(
            migraphx::generate_literal({migraphx::shape::int32_type, {1, 768, 17, 17}}));
        auto add    = mm1->add_instruction(migraphx::make_op("add"), conv, c);
        auto concat = mm1->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), mul, add);
        mm1->add_instruction(pass_op{}, concat);
    }
    migraphx::program p2 = p1;
    run_pass(p1);
    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
330
331
332
333
// (1 + x) * 2 is expected to be distributed into (2 * x) + (2 * 1).
TEST_CASE(simplify_mul_add)
{
    migraphx::program p1;
    {
        auto* mm1 = p1.get_main_module();
        auto x    = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto one  = mm1->add_literal(1);
        auto two  = mm1->add_literal(2);
        auto sum  = mm1->add_instruction(migraphx::make_op("add"), one, x);
        auto mul  = mm1->add_instruction(migraphx::make_op("mul"), sum, two);
        mm1->add_instruction(pass_op{}, mul);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2 = p2.get_main_module();
        auto x    = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto one  = mm2->add_literal(1);
        auto two  = mm2->add_literal(2);
        auto mul1 = mm2->add_instruction(migraphx::make_op("mul"), two, x);
        auto mul2 = mm2->add_instruction(migraphx::make_op("mul"), two, one);
        auto sum  = mm2->add_instruction(migraphx::make_op("add"), mul1, mul2);
        mm2->add_instruction(pass_op{}, sum);
    }
    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
358
359
360
361
362
// broadcast(x) + broadcast(y) with the same broadcast is expected to become
// broadcast(x + y), so the add runs on the un-broadcast inputs.
TEST_CASE(simplify_inner_broadcast)
{
    auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
    migraphx::program p1;
    {
        auto* mm1 = p1.get_main_module();
        auto x    = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto y    = mm1->add_parameter("y", {migraphx::shape::int32_type, {1}});
        auto xb   = mm1->add_instruction(b, x);
        auto yb   = mm1->add_instruction(b, y);
        auto sum  = mm1->add_instruction(migraphx::make_op("add"), xb, yb);
        mm1->add_instruction(pass_op{}, sum);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2 = p2.get_main_module();
        auto x    = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto y    = mm2->add_parameter("y", {migraphx::shape::int32_type, {1}});
        auto sum  = mm2->add_instruction(migraphx::make_op("add"), x, y);
        auto sumb = mm2->add_instruction(b, sum);
        mm2->add_instruction(pass_op{}, sumb);
    }
    EXPECT(p1 == p2);
}

385
386
387
// Sum of two convolutions with identical kernel shapes and default attributes:
// expected to fuse into a single convolution without changing the output shape.
TEST_CASE(simplify_add_conv1)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto x   = mm->add_parameter("x", {migraphx::shape::float_type, {1, 128, 28, 28}});
    auto w   = mm->add_literal(
        migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 3, 3}}));
    auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 28}});
    auto v = mm->add_literal(
        migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 3, 3}}));
    auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w);
    auto conv2 = mm->add_instruction(migraphx::make_op("convolution"), y, v);
    auto sum   = mm->add_instruction(migraphx::make_op("add"), conv1, conv2);
    mm->add_instruction(pass_op{}, sum);
    auto s = p.get_output_shapes().back();
    run_pass(p);
    EXPECT(s == p.get_output_shapes().back());
    EXPECT(std::count_if(
               p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1);
}

// Two 7x7 convolutions with different strides (1 vs 3) feeding an add: they are
// expected NOT to be fused, leaving two convolutions.
TEST_CASE(simplify_add_conv_no_fusion_7x7_diff_strides)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto x   = mm->add_parameter("x", {migraphx::shape::float_type, {1, 128, 14, 14}});
    auto w   = mm->add_literal(
        migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 7, 7}}));
    auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 28}});
    auto v = mm->add_literal(
        migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 7, 7}}));
    auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w);
    auto conv2 = mm->add_instruction(
        migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {3, 3}}}), y, v);
    auto sum = mm->add_instruction(migraphx::make_op("add"), conv1, conv2);
    mm->add_instruction(pass_op{}, sum);
    auto s = p.get_output_shapes().back();
    run_pass(p);
    EXPECT(s == p.get_output_shapes().back());
    // No fusion
    EXPECT(std::count_if(
               p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 2);
}

// 1x1 convolutions where the second input is twice as large and uses stride 2:
// expected to fuse into a single convolution with the same output shape.
TEST_CASE(simplify_add_conv_1x1_diff_strides1)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto x   = mm->add_parameter("x", {migraphx::shape::float_type, {1, 128, 14, 14}});
    auto w   = mm->add_literal(
        migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
    auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 28}});
    auto v = mm->add_literal(
        migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
    auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w);
    auto conv2 = mm->add_instruction(
        migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 2}}}), y, v);
    auto sum = mm->add_instruction(migraphx::make_op("add"), conv1, conv2);
    mm->add_instruction(pass_op{}, sum);
    auto s = p.get_output_shapes().back();
    run_pass(p);
    EXPECT(s == p.get_output_shapes().back());
    EXPECT(std::count_if(
               p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1);
}

// Mirror of simplify_add_conv_1x1_diff_strides1 with the strided convolution as
// the first add operand: expected to fuse into a single convolution.
TEST_CASE(simplify_add_conv_1x1_diff_strides2)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto x   = mm->add_parameter("x", {migraphx::shape::float_type, {1, 128, 28, 28}});
    auto w   = mm->add_literal(
        migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
    auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 128, 14, 14}});
    auto v = mm->add_literal(
        migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
    auto conv1 = mm->add_instruction(
        migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 2}}}), x, w);
    auto conv2 = mm->add_instruction(migraphx::make_op("convolution"), y, v);
    auto sum   = mm->add_instruction(migraphx::make_op("add"), conv1, conv2);
    mm->add_instruction(pass_op{}, sum);
    auto s = p.get_output_shapes().back();
    run_pass(p);
    EXPECT(s == p.get_output_shapes().back());
    EXPECT(std::count_if(
               p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1);
}

// 1x1 convolutions with odd spatial sizes (83 vs 165 with stride 2): expected to
// fuse into a single convolution with the same output shape.
TEST_CASE(simplify_add_conv_1x1_diff_strides_odd)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto x   = mm->add_parameter("x", {migraphx::shape::float_type, {1, 54, 83, 83}});
    auto w =
        mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {54, 54, 1, 1}}));
    auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 54, 165, 165}});
    auto v =
        mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {54, 54, 1, 1}}));
    auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w);
    auto conv2 = mm->add_instruction(
        migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 2}}}), y, v);
    auto sum = mm->add_instruction(migraphx::make_op("add"), conv1, conv2);
    mm->add_instruction(pass_op{}, sum);
    auto s = p.get_output_shapes().back();
    run_pass(p);
    EXPECT(s == p.get_output_shapes().back());
    EXPECT(std::count_if(
               p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1);
}

// First convolution uses asymmetric stride {2, 1}: the two convolutions are
// expected NOT to be fused.
TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides1)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto x   = mm->add_parameter("x", {migraphx::shape::float_type, {1, 128, 28, 14}});
    auto w   = mm->add_literal(
        migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
    auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 128, 14, 14}});
    auto v = mm->add_literal(
        migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
    auto conv1 = mm->add_instruction(
        migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 1}}}), x, w);
    auto conv2 = mm->add_instruction(migraphx::make_op("convolution"), y, v);
    auto sum   = mm->add_instruction(migraphx::make_op("add"), conv1, conv2);
    mm->add_instruction(pass_op{}, sum);
    auto s = p.get_output_shapes().back();
    run_pass(p);
    EXPECT(s == p.get_output_shapes().back());
    // No fusion
    EXPECT(std::count_if(
               p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 2);
}

// Second convolution uses asymmetric stride {2, 1}: the two convolutions are
// expected NOT to be fused.
TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides2)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto x   = mm->add_parameter("x", {migraphx::shape::float_type, {1, 128, 14, 14}});
    auto w   = mm->add_literal(
        migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
    auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 14}});
    auto v = mm->add_literal(
        migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
    auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w);
    auto conv2 = mm->add_instruction(
        migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 1}}}), y, v);
    auto sum = mm->add_instruction(migraphx::make_op("add"), conv1, conv2);
    mm->add_instruction(pass_op{}, sum);
    auto s = p.get_output_shapes().back();
    run_pass(p);
    EXPECT(s == p.get_output_shapes().back());
    // No fusion
    EXPECT(std::count_if(
               p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 2);
}

541
542
543
544
545
// concat(relu(x + 1), relu(y + 2)) is expected to become
// relu(concat(x, y) + concat(1, 2)): the concat is hoisted above the add/relu.
TEST_CASE(simplify_concat_add_relu)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {1}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto x     = mm1->add_parameter("x", s);
        auto y     = mm1->add_parameter("y", s);
        auto one   = mm1->add_literal({s, {1}});
        auto two   = mm1->add_literal({s, {2}});
        auto sum1  = mm1->add_instruction(migraphx::make_op("add"), x, one);
        auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto sum2  = mm1->add_instruction(migraphx::make_op("add"), y, two);
        auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto concat =
            mm1->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), relu1, relu2);
        mm1->add_instruction(pass_op{}, concat);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2    = p2.get_main_module();
        auto x       = mm2->add_parameter("x", s);
        auto y       = mm2->add_parameter("y", s);
        auto one     = mm2->add_literal({s, {1}});
        auto two     = mm2->add_literal({s, {2}});
        auto concat1 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y);
        auto concat2 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
        auto sum     = mm2->add_instruction(migraphx::make_op("add"), concat1, concat2);
        auto relu    = mm2->add_instruction(migraphx::make_op("relu"), sum);
        mm2->add_instruction(pass_op{}, relu);
    }
    EXPECT(p1 == p2);
}

577
578
579
580
581
// Only the trailing relu(add) inputs of the concat share a pattern; those two are
// expected to be merged into relu(concat(x, y) + concat(1, 2)) while the leading
// plain x + y input stays separate. Programs are compared after sort() since the
// instruction order may differ.
TEST_CASE(simplify_concat_add_relu_partial)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {1}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto x     = mm1->add_parameter("x", s);
        auto y     = mm1->add_parameter("y", s);
        auto one   = mm1->add_literal({s, {1}});
        auto two   = mm1->add_literal({s, {2}});
        auto sum1  = mm1->add_instruction(migraphx::make_op("add"), x, one);
        auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto sum2  = mm1->add_instruction(migraphx::make_op("add"), y, two);
        auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto sum3  = mm1->add_instruction(migraphx::make_op("add"), x, y);
        auto concat =
            mm1->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), sum3, relu1, relu2);
        mm1->add_instruction(pass_op{}, concat);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2    = p2.get_main_module();
        auto x       = mm2->add_parameter("x", s);
        auto y       = mm2->add_parameter("y", s);
        auto one     = mm2->add_literal({s, {1}});
        auto two     = mm2->add_literal({s, {2}});
        auto concat1 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y);
        auto concat2 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
        auto sum1    = mm2->add_instruction(migraphx::make_op("add"), concat1, concat2);
        auto relu    = mm2->add_instruction(migraphx::make_op("relu"), sum1);
        auto sum2    = mm2->add_instruction(migraphx::make_op("add"), x, y);
        auto concat  = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), sum2, relu);
        mm2->add_instruction(pass_op{}, concat);
    }
    EXPECT(p1.sort() == p2.sort());
}

// Two broadcast scalar literals concatenated along axis 1 are expected to be
// replaced by a single broadcast of concat(1, 2), leaving concat(x + y, that).
// Programs are compared after sort() since the instruction order may differ.
TEST_CASE(simplify_concat_add_relu_partial_broadcast)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
    migraphx::program p1;
    {
        auto* mm1 = p1.get_main_module();
        auto b    = migraphx::op::broadcast{1, {2, 1, 4, 5}};
        auto x    = mm1->add_parameter("x", s);
        auto y    = mm1->add_parameter("y", s);
        auto one  = mm1->add_literal(1);
        auto oneb = mm1->add_instruction(b, one);
        auto two  = mm1->add_literal(2);
        auto twob = mm1->add_instruction(b, two);
        auto sum  = mm1->add_instruction(migraphx::make_op("add"), x, y);
        auto concat =
            mm1->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), sum, oneb, twob);
        mm1->add_instruction(pass_op{}, concat);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2    = p2.get_main_module();
        auto b       = migraphx::op::broadcast{1, {2, 2, 4, 5}};
        auto x       = mm2->add_parameter("x", s);
        auto y       = mm2->add_parameter("y", s);
        auto one     = mm2->add_literal(1);
        auto two     = mm2->add_literal(2);
        auto concat1 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
        auto concatb = mm2->add_instruction(b, concat1);
        auto sum     = mm2->add_instruction(migraphx::make_op("add"), x, y);
        auto concat2 =
            mm2->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), sum, concatb);
        mm2->add_instruction(pass_op{}, concat2);
    }
    EXPECT(p1.sort() == p2.sort());
}

654
655
656
657
658
TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
    migraphx::program p1;
    {
659
660
661
662
663
664
665
666
667
668
669
670
671
672
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::op::broadcast{1, {2, 1, 4, 5}};
        auto x     = mm1->add_parameter("x", s);
        auto y     = mm1->add_parameter("y", s);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto sum2  = mm1->add_instruction(migraphx::make_op("add"), y, twob);
        auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto concat =
            mm1->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), relu1, relu2);
673
        mm1->add_instruction(pass_op{}, concat);
674
675
676
677
678
    }
    run_pass(p1);

    migraphx::program p2;
    {
679
        auto* mm2     = p2.get_main_module();
680
        auto b        = migraphx::op::broadcast{1, {2, 2, 4, 5}};
681
682
683
684
        auto x        = mm2->add_parameter("x", s);
        auto y        = mm2->add_parameter("y", s);
        auto one      = mm2->add_literal(1);
        auto two      = mm2->add_literal(2);
685
686
        auto concat1  = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y);
        auto concat2  = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
687
        auto concat2b = mm2->add_instruction(b, concat2);
688
689
        auto sum      = mm2->add_instruction(migraphx::make_op("add"), concat1, concat2b);
        auto relu     = mm2->add_instruction(migraphx::make_op("relu"), sum);
690
        mm2->add_instruction(pass_op{}, relu);
691
692
693
694
695
696
697
698
699
    }
    EXPECT(p1 == p2);
}

// Same pattern as the different-axis case, but concatenated along axis 0 (not
// the broadcast axis). The expected result keeps the individual broadcasts and
// concatenates them alongside concat(x, y) before a single add + relu.
TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::op::broadcast{1, {2, 1, 4, 5}};
        auto x     = mm1->add_parameter("x", s);
        auto y     = mm1->add_parameter("y", s);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto sum2  = mm1->add_instruction(migraphx::make_op("add"), y, twob);
        auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto concat =
            mm1->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), relu1, relu2);
        mm1->add_instruction(pass_op{}, concat);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2    = p2.get_main_module();
        auto b       = migraphx::op::broadcast{1, {2, 1, 4, 5}};
        auto x       = mm2->add_parameter("x", s);
        auto y       = mm2->add_parameter("y", s);
        auto one     = mm2->add_literal(1);
        auto oneb    = mm2->add_instruction(b, one);
        auto two     = mm2->add_literal(2);
        auto twob    = mm2->add_instruction(b, two);
        auto concat1 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y);
        auto concat2 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), oneb, twob);
        auto sum     = mm2->add_instruction(migraphx::make_op("add"), concat1, concat2);
        auto relu    = mm2->add_instruction(migraphx::make_op("relu"), sum);
        mm2->add_instruction(pass_op{}, relu);
    }
    // Order-sensitive comparison.
    EXPECT(p1 == p2);
}

737
738
739
740
TEST_CASE(simplify_div_const)
{
    migraphx::program p1;
    {
741
742
743
        auto* mm1 = p1.get_main_module();
        auto x    = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto two  = mm1->add_literal(2);
744
        mm1->add_instruction(migraphx::make_op("div"), x, two);
745
746
747
748
749
    }
    run_pass(p1);

    migraphx::program p2;
    {
750
751
752
        auto* mm2  = p2.get_main_module();
        auto x     = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto two   = mm2->add_literal(2);
753
754
        auto recip = mm2->insert_instruction(std::next(two), migraphx::make_op("recip"), two);
        mm2->add_instruction(migraphx::make_op("mul"), x, recip);
755
756
757
758
759
760
761
762
    }
    EXPECT(p1 == p2);
}

// x - 2 is expected to be rewritten as x + neg(2), with the neg placed
// immediately after the literal it consumes.
TEST_CASE(simplify_sub_const)
{
    migraphx::program p1;
    {
        auto* mm1 = p1.get_main_module();
        auto x    = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto two  = mm1->add_literal(2);
        mm1->add_instruction(migraphx::make_op("sub"), x, two);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2 = p2.get_main_module();
        auto x    = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto two  = mm2->add_literal(2);
        auto neg  = mm2->insert_instruction(std::next(two), migraphx::make_op("neg"), two);
        mm2->add_instruction(migraphx::make_op("add"), x, neg);
    }
    EXPECT(p1 == p2);
}

kahmed10's avatar
kahmed10 committed
781
782
783
784
TEST_CASE(simplify_rsqrt)
{
    migraphx::program p1;
    {
785
786
        auto* mm1 = p1.get_main_module();
        auto x    = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
787
788
        auto sqrt = mm1->add_instruction(migraphx::make_op("sqrt"), x);
        mm1->add_instruction(migraphx::make_op("recip"), sqrt);
kahmed10's avatar
kahmed10 committed
789
790
791
792
793
    }
    run_pass(p1);

    migraphx::program p2;
    {
794
795
        auto* mm2 = p2.get_main_module();
        auto x    = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
796
        mm2->add_instruction(migraphx::make_op("rsqrt"), x);
kahmed10's avatar
kahmed10 committed
797
798
799
800
801
802
803
804
    }
    EXPECT(p1 == p2);
}

// When sqrt(x) has another user besides the recip, the recip(sqrt) pair must
// not be fused: the program is expected to come out of the pass unchanged.
TEST_CASE(simplify_rsqrt_multi_use)
{
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto x     = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto sqrt  = mm1->add_instruction(migraphx::make_op("sqrt"), x);
        // Second consumer of sqrt, which blocks the rsqrt rewrite.
        auto add   = mm1->add_instruction(migraphx::make_op("add"), sqrt, sqrt);
        auto recip = mm1->add_instruction(migraphx::make_op("recip"), sqrt);
        mm1->add_instruction(migraphx::make_op("add"), recip, add);
    }
    // Snapshot before the pass; the pass should be a no-op.
    migraphx::program p2{p1};

    run_pass(p1);
    EXPECT(p1 == p2);
}

818
819
820
821
822
823
TEST_CASE(simplify_slice_concat)
{
    auto s = migraphx::shape{migraphx::shape::float_type, {256}};

    migraphx::program p1;
    {
824
825
826
        auto* mm1    = p1.get_main_module();
        auto x       = mm1->add_parameter("x", s);
        auto y       = mm1->add_parameter("y", s);
827
828
829
830
831
832
833
834
835
836
        auto xslice1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {128}}}), x);
        auto xslice2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {128}}, {"ends", {256}}}), x);
        auto yslice1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {128}}}), y);
        auto yslice2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {128}}, {"ends", {256}}}), y);
        auto concat = mm1->add_instruction(
            migraphx::make_op("concat", {{"axis", 0}}), xslice1, xslice2, yslice1, yslice2);
837
        mm1->add_instruction(pass_op{}, concat);
838
839
840
841
842
    }
    run_pass(p1);

    migraphx::program p2;
    {
843
844
845
        auto* mm2   = p2.get_main_module();
        auto x      = mm2->add_parameter("x", s);
        auto y      = mm2->add_parameter("y", s);
846
        auto concat = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y);
847
        mm2->add_instruction(pass_op{}, concat);
848
849
850
851
852
853
854
855
856
857
    }
    EXPECT(p1 == p2);
}

// Like simplify_slice_concat, but the slices have unequal widths
// (64 / 128 / 64). The result is still expected to be concat(x, y).
TEST_CASE(simplify_slice_concat_non_uniform)
{
    auto s = migraphx::shape{migraphx::shape::float_type, {256}};

    migraphx::program p1;
    {
        auto* mm1    = p1.get_main_module();
        auto x       = mm1->add_parameter("x", s);
        auto y       = mm1->add_parameter("y", s);
        auto xslice1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {64}}}), x);
        auto xslice2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {64}}, {"ends", {192}}}), x);
        auto xslice3 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {192}}, {"ends", {256}}}), x);
        auto yslice1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {64}}}), y);
        auto yslice2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {64}}, {"ends", {192}}}), y);
        auto yslice3 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {192}}, {"ends", {256}}}), y);
        auto concat = mm1->add_instruction(migraphx::make_op("concat", {{"axis", 0}}),
                                           xslice1,
                                           xslice2,
                                           xslice3,
                                           yslice1,
                                           yslice2,
                                           yslice3);
        mm1->add_instruction(pass_op{}, concat);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2   = p2.get_main_module();
        auto x      = mm2->add_parameter("x", s);
        auto y      = mm2->add_parameter("y", s);
        auto concat = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y);
        mm2->add_instruction(pass_op{}, concat);
    }

    EXPECT(p1 == p2);
}

// The slices are concatenated out of order (0-64, 192-256, 64-192), so they do
// not reassemble the original tensors. The pass is expected to leave the
// program unchanged.
TEST_CASE(simplify_slice_concat_flipped)
{
    auto s = migraphx::shape{migraphx::shape::float_type, {256}};

    migraphx::program p1;
    {
        auto* mm1    = p1.get_main_module();
        auto x       = mm1->add_parameter("x", s);
        auto y       = mm1->add_parameter("y", s);
        auto xslice1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {64}}}), x);
        auto xslice2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {192}}, {"ends", {256}}}), x);
        auto xslice3 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {64}}, {"ends", {192}}}), x);
        auto yslice1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {64}}}), y);
        auto yslice2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {192}}, {"ends", {256}}}), y);
        auto yslice3 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {64}}, {"ends", {192}}}), y);
        auto concat = mm1->add_instruction(migraphx::make_op("concat", {{"axis", 0}}),
                                           xslice1,
                                           xslice2,
                                           xslice3,
                                           yslice1,
                                           yslice2,
                                           yslice3);
        mm1->add_instruction(pass_op{}, concat);
    }
    // Snapshot before the pass; the pass should be a no-op.
    migraphx::program p2 = p1;
    run_pass(p1);

    EXPECT(p1 == p2);
}

932
933
934
935
936
TEST_CASE(simplify_split_add_relu)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
    migraphx::program p1;
    {
937
        auto* mm1  = p1.get_main_module();
938
        auto b     = migraphx::op::broadcast{1, {3, 1, 4}};
939
        auto input = mm1->add_parameter("input", s);
940
941
942
943
        auto x     = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
        auto y = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
944
945
946
947
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
948
949
950
951
952
        auto sum1  = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto sum2  = mm1->add_instruction(migraphx::make_op("add"), y, twob);
        auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto add   = mm1->add_instruction(migraphx::make_op("add"), relu1, relu2);
953
        mm1->add_instruction(pass_op{}, add);
954
955
956
957
958
    }
    run_pass(p1);

    migraphx::program p2;
    {
959
        auto* mm2    = p2.get_main_module();
960
        auto b       = migraphx::op::broadcast{1, {3, 2, 4}};
961
962
963
        auto input   = mm2->add_parameter("input", s);
        auto one     = mm2->add_literal(1);
        auto two     = mm2->add_literal(2);
964
        auto concat  = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
965
        auto concatb = mm2->add_instruction(b, concat);
966
967
968
969
970
971
972
        auto sum     = mm2->add_instruction(migraphx::make_op("add"), input, concatb);
        auto relu    = mm2->add_instruction(migraphx::make_op("relu"), sum);
        auto x       = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), relu);
        auto y = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), relu);
        auto add = mm2->add_instruction(migraphx::make_op("add"), x, y);
973
        mm2->add_instruction(pass_op{}, add);
974
975
976
977
978
979
980
981
982
    }
    EXPECT(p1.sort() == p2.sort());
}

// Split add+relu where each branch is reshaped to {3, 4} before the final add.
// The expected result fuses the add + relu on the whole input, reshapes it to
// {3, 8}, and slices that along axis 1.
TEST_CASE(simplify_split_add_relu_reshape)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::op::broadcast{1, {3, 1, 4}};
        auto r     = migraphx::op::reshape{{3, 4}};
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
        auto y = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
        auto one      = mm1->add_literal(1);
        auto oneb     = mm1->add_instruction(b, one);
        auto two      = mm1->add_literal(2);
        auto twob     = mm1->add_instruction(b, two);
        auto sum1     = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
        auto relu1    = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto reshape1 = mm1->add_instruction(r, relu1);
        auto sum2     = mm1->add_instruction(migraphx::make_op("add"), y, twob);
        auto relu2    = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto reshape2 = mm1->add_instruction(r, relu2);
        auto add      = mm1->add_instruction(migraphx::make_op("add"), reshape1, reshape2);
        mm1->add_instruction(pass_op{}, add);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2    = p2.get_main_module();
        auto b       = migraphx::op::broadcast{1, {3, 2, 4}};
        auto input   = mm2->add_parameter("input", s);
        auto one     = mm2->add_literal(1);
        auto two     = mm2->add_literal(2);
        auto concat  = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
        auto concatb = mm2->add_instruction(b, concat);
        auto sum     = mm2->add_instruction(migraphx::make_op("add"), input, concatb);
        auto relu    = mm2->add_instruction(migraphx::make_op("relu"), sum);
        auto rsp     = mm2->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 8}}}), relu);
        auto slc1    = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {4}}}), rsp);
        auto slc2 = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {4}}, {"ends", {8}}}), rsp);
        auto add = mm2->add_instruction(migraphx::make_op("add"), slc1, slc2);
        mm2->add_instruction(pass_op{}, add);
    }
    EXPECT(p1.sort() == p2.sort());
}

// The two branches slice `input` along different axes (1 and 3), so they cannot
// be merged. The pass is expected to leave the program unchanged.
TEST_CASE(simplify_slice_different_axis)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 2}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto r     = migraphx::op::reshape{{3, 2, 4}};
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
        auto y = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {1}}}), input);
        auto one  = mm1->add_literal(1);
        auto oneb = mm1->add_instruction(
            migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {3, 1, 4, 2}}}), one);
        auto two  = mm1->add_literal(2);
        auto twob = mm1->add_instruction(
            migraphx::make_op("broadcast", {{"axis", 3}, {"dims", {3, 2, 4, 1}}}), two);
        auto sum1     = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
        auto relu1    = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto reshape1 = mm1->add_instruction(r, relu1);
        auto sum2     = mm1->add_instruction(migraphx::make_op("add"), y, twob);
        auto relu2    = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto reshape2 = mm1->add_instruction(r, relu2);
        auto add      = mm1->add_instruction(migraphx::make_op("add"), reshape1, reshape2);
        mm1->add_instruction(pass_op{}, add);
    }
    // Snapshot before the pass; the pass should be a no-op.
    migraphx::program p2 = p1;
    run_pass(p1);

    EXPECT(p1.sort() == p2.sort());
}

// Slices [2,3) and [1,2) along axis 1 leave out [0,1), so the split does not
// cover the whole input. The pass is expected to leave the program unchanged.
TEST_CASE(simplify_slice_missing_begining_slice)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::op::broadcast{1, {3, 1, 4}};
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input);
        auto y = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto sum2  = mm1->add_instruction(migraphx::make_op("add"), y, twob);
        auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto add   = mm1->add_instruction(migraphx::make_op("add"), relu1, relu2);
        mm1->add_instruction(pass_op{}, add);
    }
    // Snapshot before the pass; the pass should be a no-op.
    migraphx::program p2 = p1;
    run_pass(p1);

    EXPECT(p1.sort() == p2.sort());
}

// Slices [2,3) and [0,1) along axis 1 leave out [1,2), so the split does not
// cover the whole input. The pass is expected to leave the program unchanged.
TEST_CASE(simplify_slice_missing_middle_slice)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::op::broadcast{1, {3, 1, 4}};
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input);
        auto y = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto sum2  = mm1->add_instruction(migraphx::make_op("add"), y, twob);
        auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto add   = mm1->add_instruction(migraphx::make_op("add"), relu1, relu2);
        mm1->add_instruction(pass_op{}, add);
    }
    // Snapshot before the pass; the pass should be a no-op.
    migraphx::program p2 = p1;
    run_pass(p1);

    EXPECT(p1.sort() == p2.sort());
}

// Slices [0,1) and [1,2) along axis 1 (size 3) leave out [2,3), so the split
// does not cover the whole input. The pass is expected to leave the program
// unchanged.
TEST_CASE(simplify_slice_missing_end_slice)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::op::broadcast{1, {3, 1, 4}};
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
        auto y = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto sum2  = mm1->add_instruction(migraphx::make_op("add"), y, twob);
        auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto add   = mm1->add_instruction(migraphx::make_op("add"), relu1, relu2);
        mm1->add_instruction(pass_op{}, add);
    }
    // Snapshot before the pass; the pass should be a no-op.
    migraphx::program p2 = p1;
    run_pass(p1);

    EXPECT(p1.sort() == p2.sort());
}

// Split add+relu whose branches are re-concatenated along the split axis. The
// slices and the concat are expected to cancel out, leaving a single add + relu
// on the whole input.
TEST_CASE(simplify_split_add_relu_concat_same_axis)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::op::broadcast{1, {3, 1, 4}};
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
        auto y = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto sum2  = mm1->add_instruction(migraphx::make_op("add"), y, twob);
        auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto concat =
            mm1->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), relu1, relu2);
        mm1->add_instruction(pass_op{}, concat);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2    = p2.get_main_module();
        auto b       = migraphx::op::broadcast{1, {3, 2, 4}};
        auto input   = mm2->add_parameter("input", s);
        auto one     = mm2->add_literal(1);
        auto two     = mm2->add_literal(2);
        auto concat  = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
        auto concatb = mm2->add_instruction(b, concat);
        auto sum     = mm2->add_instruction(migraphx::make_op("add"), input, concatb);
        auto relu    = mm2->add_instruction(migraphx::make_op("relu"), sum);
        mm2->add_instruction(pass_op{}, relu);
    }
    EXPECT(p1.sort() == p2.sort());
}

// The slices split `input` along two axes (1 and 3) at once. The pass is
// expected to leave the program unchanged.
TEST_CASE(simplify_split_add_relu_multi_axes)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 6}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::op::broadcast{1, {3, 1, 4, 3}};
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1, 3}}, {"starts", {0, 0}}, {"ends", {1, 3}}}),
            input);
        auto y = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1, 3}}, {"starts", {1, 3}}, {"ends", {2, 6}}}),
            input);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto sum2  = mm1->add_instruction(migraphx::make_op("add"), y, twob);
        auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto add   = mm1->add_instruction(migraphx::make_op("add"), relu1, relu2);
        mm1->add_instruction(pass_op{}, add);
    }
    // Snapshot before the pass; the pass should be a no-op.
    migraphx::program p2 = p1;
    run_pass(p1);
    EXPECT(p1.sort() == p2.sort());
}

// Split add+relu where the first slice `x` is also consumed directly by the
// final add. The add + relu is still expected to be fused over the whole input,
// while a separate slice of `input` is kept for that extra use.
TEST_CASE(simplify_split_add_relu_used_multiple_split1)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::op::broadcast{1, {3, 1, 4}};
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
        auto y = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto sum2  = mm1->add_instruction(migraphx::make_op("add"), y, twob);
        auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto add1  = mm1->add_instruction(migraphx::make_op("add"), relu1, relu2);
        // Extra use of the raw slice `x`.
        auto add2 = mm1->add_instruction(migraphx::make_op("add"), x, add1);
        mm1->add_instruction(pass_op{}, add2);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2  = p2.get_main_module();
        auto b     = migraphx::op::broadcast{1, {3, 2, 4}};
        auto input = mm2->add_parameter("input", s);
        auto slice = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
        auto one     = mm2->add_literal(1);
        auto two     = mm2->add_literal(2);
        auto concat  = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
        auto concatb = mm2->add_instruction(b, concat);
        auto sum     = mm2->add_instruction(migraphx::make_op("add"), input, concatb);
        auto relu    = mm2->add_instruction(migraphx::make_op("relu"), sum);
        auto x       = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), relu);
        auto y = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), relu);
        auto add1 = mm2->add_instruction(migraphx::make_op("add"), x, y);
        auto add2 = mm2->add_instruction(migraphx::make_op("add"), slice, add1);
        mm2->add_instruction(pass_op{}, add2);
    }
    EXPECT(p1.sort() == p2.sort());
}

TEST_CASE(simplify_split_add_relu_used_multiple_split2)
{
    // Two slices of `input` each get a broadcast constant added and a relu applied.
    // The first slice also feeds an independent relu (z). The expected program
    // performs the add/relu once on the whole input (with the constants
    // concatenated) and slices afterwards. It keeps a separate slice of the input
    // for z.
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::op::broadcast{1, {3, 1, 4}};
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
        auto y = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
        auto z     = mm1->add_instruction(migraphx::make_op("relu"), x);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto sum2  = mm1->add_instruction(migraphx::make_op("add"), y, twob);
        auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto add1  = mm1->add_instruction(migraphx::make_op("add"), relu1, relu2);
        auto add2  = mm1->add_instruction(migraphx::make_op("add"), z, add1);
        mm1->add_instruction(pass_op{}, add2);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2  = p2.get_main_module();
        auto b     = migraphx::op::broadcast{1, {3, 2, 4}};
        auto input = mm2->add_parameter("input", s);
        auto slice = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
        auto z       = mm2->add_instruction(migraphx::make_op("relu"), slice);
        auto one     = mm2->add_literal(1);
        auto two     = mm2->add_literal(2);
        auto concat  = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
        auto concatb = mm2->add_instruction(b, concat);
        auto sum     = mm2->add_instruction(migraphx::make_op("add"), input, concatb);
        auto relu    = mm2->add_instruction(migraphx::make_op("relu"), sum);
        auto x       = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), relu);
        auto y = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), relu);
        auto add1 = mm2->add_instruction(migraphx::make_op("add"), x, y);
        auto add2 = mm2->add_instruction(migraphx::make_op("add"), z, add1);
        mm2->add_instruction(pass_op{}, add2);
    }
    EXPECT(p1.sort() == p2.sort());
}

TEST_CASE(simplify_split_between_add)
{
    // Two slices of the same input are added directly, with no other op on either
    // slice. The pass is expected to leave this program unchanged (p2 is a copy
    // taken before the pass runs).
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
        auto y = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
        auto sum = mm1->add_instruction(migraphx::make_op("add"), x, y);
        mm1->add_instruction(pass_op{}, sum);
    }
    migraphx::program p2 = p1;
    run_pass(p1);
    EXPECT(p1.sort() == p2.sort());
}

TEST_CASE(simplify_dot_horiz)
{
    // Two dots share the same left operand (`input`). They are expected to merge
    // into one dot against the right operands concatenated on axis 2. Slices of
    // that result on axis 2 recover each original product.
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto input = mm1->add_parameter("input", s);
        auto a     = mm1->add_literal(migraphx::generate_literal(s, 0));
        auto b     = mm1->add_literal(migraphx::generate_literal(s, 1));
        auto x     = mm1->add_instruction(migraphx::make_op("dot"), input, a);
        auto y     = mm1->add_instruction(migraphx::make_op("dot"), input, b);
        auto sum   = mm1->add_instruction(migraphx::make_op("add"), x, y);
        mm1->add_instruction(pass_op{}, sum);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2   = p2.get_main_module();
        auto input  = mm2->add_parameter("input", s);
        auto a      = mm2->add_literal(migraphx::generate_literal(s, 0));
        auto b      = mm2->add_literal(migraphx::generate_literal(s, 1));
        auto concat = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 2}}), a, b);
        auto dot    = mm2->add_instruction(migraphx::make_op("dot"), input, concat);
        auto x      = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), dot);
        auto y = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {4}}}), dot);
        auto sum = mm2->add_instruction(migraphx::make_op("add"), x, y);
        mm2->add_instruction(pass_op{}, sum);
    }
    EXPECT(p1.sort() == p2.sort());
}

TEST_CASE(simplify_dot_horiz_same_constant)
{
    // Both dots use the same literal `a` as their right operand. The merged dot is
    // expected to concatenate that literal with itself on axis 2, then slice the
    // result back apart.
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto input = mm1->add_parameter("input", s);
        auto a     = mm1->add_literal(migraphx::generate_literal(s, 0));
        auto x     = mm1->add_instruction(migraphx::make_op("dot"), input, a);
        auto y     = mm1->add_instruction(migraphx::make_op("dot"), input, a);
        auto sum   = mm1->add_instruction(migraphx::make_op("add"), x, y);
        mm1->add_instruction(pass_op{}, sum);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2   = p2.get_main_module();
        auto input  = mm2->add_parameter("input", s);
        auto a      = mm2->add_literal(migraphx::generate_literal(s, 0));
        auto concat = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 2}}), a, a);
        auto dot    = mm2->add_instruction(migraphx::make_op("dot"), input, concat);
        auto x      = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), dot);
        auto y = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {4}}}), dot);
        auto sum = mm2->add_instruction(migraphx::make_op("add"), x, y);
        mm2->add_instruction(pass_op{}, sum);
    }
    EXPECT(p1.sort() == p2.sort());
}

TEST_CASE(simplify_dot_horiz_flipped)
{
    // `input` is the left operand of the first dot but the right operand of the
    // second. The pass is expected to leave this program unchanged (p2 is a copy
    // taken before the pass runs).
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto input = mm1->add_parameter("input", s);
        auto a     = mm1->add_literal(migraphx::generate_literal(s, 0));
        auto b     = mm1->add_literal(migraphx::generate_literal(s, 1));
        auto x     = mm1->add_instruction(migraphx::make_op("dot"), input, a);
        auto y     = mm1->add_instruction(migraphx::make_op("dot"), b, input);
        auto sum   = mm1->add_instruction(migraphx::make_op("add"), x, y);
        mm1->add_instruction(pass_op{}, sum);
    }

    migraphx::program p2 = p1;
    run_pass(p1);
    EXPECT(p1.sort() == p2.sort());
}

TEST_CASE(simplify_conv_horiz)
{
    // Two convolutions share `input`. They are expected to merge into one
    // convolution over weights concatenated on axis 0. Slices of the output on
    // axis 1 (0..12 and 12..24) recover each original result.
    auto s  = migraphx::shape{migraphx::shape::int32_type, {8, 3, 64, 64}};
    auto ws = migraphx::shape{migraphx::shape::int32_type, {12, 3, 3, 3}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto input = mm1->add_parameter("input", s);
        auto a     = mm1->add_literal(migraphx::generate_literal(ws, 0));
        auto b     = mm1->add_literal(migraphx::generate_literal(ws, 1));
        auto x     = mm1->add_instruction(migraphx::make_op("convolution"), input, a);
        auto y     = mm1->add_instruction(migraphx::make_op("convolution"), input, b);
        auto sum   = mm1->add_instruction(migraphx::make_op("add"), x, y);
        mm1->add_instruction(pass_op{}, sum);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2   = p2.get_main_module();
        auto input  = mm2->add_parameter("input", s);
        auto a      = mm2->add_literal(migraphx::generate_literal(ws, 0));
        auto b      = mm2->add_literal(migraphx::generate_literal(ws, 1));
        auto concat = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), a, b);
        auto conv   = mm2->add_instruction(migraphx::make_op("convolution"), input, concat);
        auto x      = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}), conv);
        auto y = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}), conv);
        auto sum = mm2->add_instruction(migraphx::make_op("add"), x, y);
        mm2->add_instruction(pass_op{}, sum);
    }
    EXPECT(p1.sort() == p2.sort());
}

1462
1463
1464
1465
1466
1467
TEST_CASE(simplify_group_conv_horiz)
{
    auto s  = migraphx::shape{migraphx::shape::int32_type, {1, 32, 111, 111}};
    auto ws = migraphx::shape{migraphx::shape::int32_type, {32, 1, 7, 7}};
    migraphx::program p1;
    {
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
        auto* mm1  = p1.get_main_module();
        auto x     = mm1->add_parameter("x", s);
        auto w1    = mm1->add_literal(migraphx::generate_literal(ws, 1));
        auto w2    = mm1->add_literal(migraphx::generate_literal(ws, 2));
        auto conv1 = mm1->add_instruction(
            migraphx::make_op(
                "convolution",
                {{"padding", {3, 3}}, {"stride", {2, 2}}, {"dilation", {1, 1}}, {"group", 32}}),
            x,
            w1);
        auto conv2 = mm1->add_instruction(
            migraphx::make_op(
                "convolution",
                {{"padding", {3, 3}}, {"stride", {2, 2}}, {"dilation", {1, 1}}, {"group", 32}}),
            x,
            w2);
1484
        mm1->add_instruction(pass_op{}, conv1, conv2);
1485
1486
1487
1488
1489
1490
1491
1492
    }
    migraphx::program p2 = p1;
    run_pass(p1);

    EXPECT(p1.sort() == p2.sort());
}

TEST_CASE(simplify_conv_horiz_grouped)
{
    // Two independent horizontal fusions share `input`. The convolution pair is
    // expected to merge: weights are concatenated on axis 0 and the output is
    // sliced on axis 1. The dot pair is expected to merge: right operands are
    // concatenated on axis 3 and the output is sliced on axis 3.
    auto s   = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
    auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}};
    auto ws2 = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto input = mm1->add_parameter("input", s);
        auto a     = mm1->add_literal(migraphx::generate_literal(ws1, 0));
        auto b     = mm1->add_literal(migraphx::generate_literal(ws1, 1));
        auto c     = mm1->add_literal(migraphx::generate_literal(ws2, 2));
        auto d     = mm1->add_literal(migraphx::generate_literal(ws2, 3));
        auto convx =
            mm1->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, a);
        auto convy =
            mm1->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, b);
        auto dotx = mm1->add_instruction(migraphx::make_op("dot"), input, c);
        auto doty = mm1->add_instruction(migraphx::make_op("dot"), input, d);
        auto sum1 = mm1->add_instruction(migraphx::make_op("add"), convx, convy);
        auto sum2 = mm1->add_instruction(migraphx::make_op("add"), dotx, doty);
        auto sum3 = mm1->add_instruction(migraphx::make_op("add"), sum1, sum2);
        mm1->add_instruction(pass_op{}, sum3);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2    = p2.get_main_module();
        auto input   = mm2->add_parameter("input", s);
        auto a       = mm2->add_literal(migraphx::generate_literal(ws1, 0));
        auto b       = mm2->add_literal(migraphx::generate_literal(ws1, 1));
        auto c       = mm2->add_literal(migraphx::generate_literal(ws2, 2));
        auto d       = mm2->add_literal(migraphx::generate_literal(ws2, 3));
        auto concat1 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), a, b);
        auto concat2 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 3}}), c, d);
        auto conv    = mm2->add_instruction(
            migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, concat1);
        auto convx = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {6}}}), conv);
        auto convy = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {6}}, {"ends", {12}}}), conv);
        auto sum1 = mm2->add_instruction(migraphx::make_op("add"), convx, convy);
        auto dot  = mm2->add_instruction(migraphx::make_op("dot"), input, concat2);
        auto dotx = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {64}}}), dot);
        auto doty = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {128}}}), dot);
        auto sum2 = mm2->add_instruction(migraphx::make_op("add"), dotx, doty);
        auto sum3 = mm2->add_instruction(migraphx::make_op("add"), sum1, sum2);
        mm2->add_instruction(pass_op{}, sum3);
    }
    EXPECT(p1.sort() == p2.sort());
}

1548
TEST_CASE(simplify_conv_horiz_grouped_extra1)
1549
1550
1551
1552
1553
1554
{
    auto s   = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
    auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}};
    auto ws2 = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
    migraphx::program p1;
    {
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
        auto* mm1  = p1.get_main_module();
        auto input = mm1->add_parameter("input", s);
        auto a     = mm1->add_literal(migraphx::generate_literal(ws1, 0));
        auto b     = mm1->add_literal(migraphx::generate_literal(ws1, 1));
        auto c     = mm1->add_literal(migraphx::generate_literal(ws2, 2));
        auto d     = mm1->add_literal(migraphx::generate_literal(ws2, 3));
        auto e     = mm1->add_literal(migraphx::generate_literal(s, 4));
        auto convx =
            mm1->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, a);
        auto convy =
            mm1->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, b);
        auto dotx    = mm1->add_instruction(migraphx::make_op("dot"), input, c);
        auto doty    = mm1->add_instruction(migraphx::make_op("dot"), input, d);
        auto sqdiffx = mm1->add_instruction(migraphx::make_op("sqdiff"), input, e);
        auto sum1    = mm1->add_instruction(migraphx::make_op("add"), convx, convy);
        auto sum2    = mm1->add_instruction(migraphx::make_op("add"), dotx, doty);
1571
        auto sum3    = sqdiffx;
1572
1573
        auto sum4    = mm1->add_instruction(migraphx::make_op("add"), sum1, sum2);
        auto sum5    = mm1->add_instruction(migraphx::make_op("add"), sum4, sum3);
1574
        mm1->add_instruction(pass_op{}, sum5);
1575
1576
1577
1578
1579
    }
    run_pass(p1);

    migraphx::program p2;
    {
1580
1581
1582
1583
1584
1585
1586
        auto* mm2    = p2.get_main_module();
        auto input   = mm2->add_parameter("input", s);
        auto a       = mm2->add_literal(migraphx::generate_literal(ws1, 0));
        auto b       = mm2->add_literal(migraphx::generate_literal(ws1, 1));
        auto c       = mm2->add_literal(migraphx::generate_literal(ws2, 2));
        auto d       = mm2->add_literal(migraphx::generate_literal(ws2, 3));
        auto e       = mm2->add_literal(migraphx::generate_literal(s, 4));
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
        auto concat1 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), a, b);
        auto concat2 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 3}}), c, d);
        auto conv    = mm2->add_instruction(
            migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, concat1);
        auto convx = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {6}}}), conv);
        auto convy = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {6}}, {"ends", {12}}}), conv);
        auto sum1 = mm2->add_instruction(migraphx::make_op("add"), convx, convy);
        auto dot  = mm2->add_instruction(migraphx::make_op("dot"), input, concat2);
        auto dotx = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {64}}}), dot);
        auto doty = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {128}}}), dot);
        auto sum2    = mm2->add_instruction(migraphx::make_op("add"), dotx, doty);
        auto sqdiffx = mm2->add_instruction(migraphx::make_op("sqdiff"), input, e);
1603
        auto sum3    = sqdiffx;
1604
1605
        auto sum4    = mm2->add_instruction(migraphx::make_op("add"), sum1, sum2);
        auto sum5    = mm2->add_instruction(migraphx::make_op("add"), sum4, sum3);
1606
        mm2->add_instruction(pass_op{}, sum5);
1607
1608
1609
1610
    }
    EXPECT(p1.sort() == p2.sort());
}

1611
TEST_CASE(simplify_conv_horiz_grouped_extra2)
1612
1613
1614
1615
1616
1617
{
    auto s   = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
    auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}};
    auto ws2 = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
    migraphx::program p1;
    {
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
        auto* mm1  = p1.get_main_module();
        auto input = mm1->add_parameter("input", s);
        auto a     = mm1->add_literal(migraphx::generate_literal(ws1, 0));
        auto b     = mm1->add_literal(migraphx::generate_literal(ws1, 1));
        auto c     = mm1->add_literal(migraphx::generate_literal(ws2, 2));
        auto d     = mm1->add_literal(migraphx::generate_literal(ws2, 3));
        auto e     = mm1->add_literal(migraphx::generate_literal(s, 4));
        auto f     = mm1->add_literal(migraphx::generate_literal(s, 5));
        auto convx =
            mm1->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, a);
        auto convy =
            mm1->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, b);
        auto dotx    = mm1->add_instruction(migraphx::make_op("dot"), input, c);
        auto doty    = mm1->add_instruction(migraphx::make_op("dot"), input, d);
        auto sqdiffx = mm1->add_instruction(migraphx::make_op("sqdiff"), input, e);
        auto sqdiffy = mm1->add_instruction(migraphx::make_op("sqdiff"), input, f);
        auto sum1    = mm1->add_instruction(migraphx::make_op("add"), convx, convy);
        auto sum2    = mm1->add_instruction(migraphx::make_op("add"), dotx, doty);
        auto sum3    = mm1->add_instruction(migraphx::make_op("add"), sqdiffx, sqdiffy);
        auto sum4    = mm1->add_instruction(migraphx::make_op("add"), sum1, sum2);
        auto sum5    = mm1->add_instruction(migraphx::make_op("add"), sum4, sum3);
1639
        mm1->add_instruction(pass_op{}, sum5);
1640
1641
1642
1643
1644
    }
    run_pass(p1);

    migraphx::program p2;
    {
1645
1646
1647
1648
1649
1650
1651
1652
        auto* mm2    = p2.get_main_module();
        auto input   = mm2->add_parameter("input", s);
        auto a       = mm2->add_literal(migraphx::generate_literal(ws1, 0));
        auto b       = mm2->add_literal(migraphx::generate_literal(ws1, 1));
        auto c       = mm2->add_literal(migraphx::generate_literal(ws2, 2));
        auto d       = mm2->add_literal(migraphx::generate_literal(ws2, 3));
        auto e       = mm2->add_literal(migraphx::generate_literal(s, 4));
        auto f       = mm2->add_literal(migraphx::generate_literal(s, 5));
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
        auto concat1 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), a, b);
        auto concat2 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 3}}), c, d);
        auto conv    = mm2->add_instruction(
            migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, concat1);
        auto convx = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {6}}}), conv);
        auto convy = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {6}}, {"ends", {12}}}), conv);
        auto sum1 = mm2->add_instruction(migraphx::make_op("add"), convx, convy);
        auto dot  = mm2->add_instruction(migraphx::make_op("dot"), input, concat2);
        auto dotx = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {64}}}), dot);
        auto doty = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {128}}}), dot);
        auto sum2    = mm2->add_instruction(migraphx::make_op("add"), dotx, doty);
        auto sqdiffx = mm2->add_instruction(migraphx::make_op("sqdiff"), input, e);
        auto sqdiffy = mm2->add_instruction(migraphx::make_op("sqdiff"), input, f);
        auto sum3    = mm2->add_instruction(migraphx::make_op("add"), sqdiffx, sqdiffy);
        auto sum4    = mm2->add_instruction(migraphx::make_op("add"), sum1, sum2);
        auto sum5    = mm2->add_instruction(migraphx::make_op("add"), sum4, sum3);
1673
        mm2->add_instruction(pass_op{}, sum5);
1674
1675
1676
1677
    }
    EXPECT(p1.sort() == p2.sort());
}

1678
1679
1680
1681
TEST_CASE(simplify_mul_slice_conv_horiz_fusion)
{
    migraphx::program p1;
    {
1682
1683
1684
        auto* mm1 = p1.get_main_module();
        auto x    = mm1->add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
        auto w    = mm1->add_literal(
1685
            migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
1686
1687
1688
        auto conv   = mm1->add_instruction(migraphx::make_op("convolution"), x, w);
        auto slice1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv);
1689
        auto a1 =
1690
            mm1->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1));
1691
1692
1693
        auto b1 = mm1->add_instruction(
            migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a1);
        auto mul = mm1->add_instruction(migraphx::make_op("mul"), slice1, b1);
1694
        auto a2 =
1695
            mm1->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 2));
1696
1697
1698
        auto b2 = mm1->add_instruction(
            migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a2);
        auto add1 = mm1->add_instruction(migraphx::make_op("add"), mul, b2);
1699
        auto a3 =
1700
            mm1->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3));
1701
1702
1703
1704
1705
        auto b3 = mm1->add_instruction(
            migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a3);
        auto slice2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), conv);
        auto add2 = mm1->add_instruction(migraphx::make_op("add"), slice2, b3);
1706
        mm1->add_instruction(pass_op{}, add1, add2);
1707
1708
1709
1710
1711
    }
    run_pass(p1);

    migraphx::program p2;
    {
1712
1713
1714
        auto* mm2 = p2.get_main_module();
        auto x    = mm2->add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
        auto w    = mm2->add_literal(
1715
            migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
1716
1717
        auto wslice1 = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {384}}}), w);
1718
        auto a1 =
1719
            mm2->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1));
1720
1721
1722
1723
1724
1725
1726
1727
        auto b1 = mm2->add_instruction(
            migraphx::make_op("broadcast", {{"axis", 0}, {"dims", {384, 1024, 1, 1}}}), a1);
        auto mul     = mm2->add_instruction(migraphx::make_op("mul"), b1, wslice1);
        auto wslice2 = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {384}}, {"ends", {768}}}), w);
        auto concat1 =
            mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), mul, wslice2);
        auto conv = mm2->add_instruction(migraphx::make_op("convolution"), x, concat1);
1728
        auto a2 =
1729
            mm2->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 2));
1730
        auto a3 =
1731
            mm2->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3));
1732
1733
1734
1735
1736
1737
1738
1739
        auto concat2 = mm2->add_instruction(migraphx::make_op("concat"), a2, a3);
        auto b4      = mm2->add_instruction(
            migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 768, 17, 17}}}), concat2);
        auto add    = mm2->add_instruction(migraphx::make_op("add"), conv, b4);
        auto slice1 = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), add);
        auto slice2 = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), add);
1740
        mm2->add_instruction(pass_op{}, slice1, slice2);
1741
1742
1743
    }
    EXPECT(p1.sort() == p2.sort());
}
1744
1745
1746
1747
1748
1749
TEST_CASE(reorder_reshape_slice)
{
    // The input is split into three equal slices on axis 2. Each slice is made
    // contiguous and reshaped to {batch, 128, 10, 64}. The expected program
    // instead reshapes the whole input once to {batch, 128, 30, 64} and slices
    // on axis 2 (0..10, 10..20, 20..30), removing the contiguous ops.
    // Checked for several batch sizes.
    std::vector<int64_t> perm0 = {0, 2, 1, 3};
    std::vector<int64_t> perm1 = {0, 2, 3, 1};
    auto create_p1             = [&](std::size_t batch_size) {
        migraphx::program p1;
        auto* mm1  = p1.get_main_module();
        auto s     = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
        auto input = mm1->add_parameter("input", s);
        auto slc0  = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {640}}}), input);
        auto slc1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {640}}, {"ends", {1280}}}),
            input);
        auto slc2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}),
            input);

        auto c0 = mm1->add_instruction(migraphx::make_op("contiguous"), slc0);
        auto c1 = mm1->add_instruction(migraphx::make_op("contiguous"), slc1);
        auto c2 = mm1->add_instruction(migraphx::make_op("contiguous"), slc2);

        std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 128, 10, 64};
        auto r0 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
        auto r1 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
        auto r2 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);

        auto t0 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), r0);
        auto t1 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), r1);
        auto t2 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), r2);

        auto sum = mm1->add_instruction(migraphx::make_op("add"), t0, t1);
        auto ret = mm1->add_instruction(migraphx::make_op("dot"), sum, t2);
        mm1->add_return({ret});

        return p1;
    };

    auto create_p2 = [&](std::size_t batch_size) {
        migraphx::program p2;
        auto* mm2  = p2.get_main_module();
        auto s     = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
        auto input = mm2->add_parameter("input", s);
        std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 128, 30, 64};
        auto r = mm2->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);

        auto slc0 = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {10}}}), r);
        auto slc1 = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {10}}, {"ends", {20}}}), r);
        auto slc2 = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {20}}, {"ends", {30}}}), r);

        auto t0 = mm2->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc0);
        auto t1 = mm2->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc1);
        auto t2 = mm2->add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), slc2);

        auto sum = mm2->add_instruction(migraphx::make_op("add"), t0, t1);
        auto ret = mm2->add_instruction(migraphx::make_op("dot"), sum, t2);
        mm2->add_return({ret});

        return p2;
    };

    auto test = [&](std::size_t batch_size) {
        auto p1 = create_p1(batch_size);
        run_pass(p1);
        auto p2 = create_p2(batch_size);
        EXPECT(p1.sort() == p2.sort());
    };

    test(1);
    test(4);
    test(8);
}

TEST_CASE(reorder_reshape_slice_move_axis1)
1821
1822
1823
{
    auto create_p1 = [](std::size_t batch_size) {
        migraphx::program p1;
1824
1825
        auto* mm1 = p1.get_main_module();
        auto s    = migraphx::shape{migraphx::shape::float_type, {batch_size, 256, 96}};
1826
1827
        std::vector<int64_t> perm0 = {0, 2, 1, 3};
        std::vector<int64_t> perm1 = {0, 2, 3, 1};
1828
        auto input                 = mm1->add_parameter("input", s);
1829
1830
1831
1832
1833
1834
        auto slc0                  = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), input);
        auto slc1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), input);
        auto slc2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {64}}, {"ends", {96}}}), input);
1835

1836
1837
1838
        auto c0 = mm1->add_instruction(migraphx::make_op("contiguous"), slc0);
        auto c1 = mm1->add_instruction(migraphx::make_op("contiguous"), slc1);
        auto c2 = mm1->add_instruction(migraphx::make_op("contiguous"), slc2);
1839

1840
        std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 64, 4, 32};
1841
1842
1843
        auto r0 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
        auto r1 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
        auto r2 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);
1844

1845
1846
1847
        auto t0 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), r0);
        auto t1 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), r1);
        auto t2 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), r2);
1848

1849
1850
        auto sum = mm1->add_instruction(migraphx::make_op("add"), t0, t1);
        auto ret = mm1->add_instruction(migraphx::make_op("dot"), sum, t2);
1851
        mm1->add_return({ret});
1852
1853
1854
1855

        return p1;
    };

1856
1857
    auto create_p2 = [](std::size_t batch_size) {
        migraphx::program p;
1858
1859
        auto* mm = p.get_main_module();
        auto s   = migraphx::shape{migraphx::shape::float_type, {batch_size, 256, 96}};
1860
1861
        std::vector<int64_t> perm0 = {0, 2, 1, 3};
        std::vector<int64_t> perm1 = {0, 2, 3, 1};
1862
        auto input                 = mm->add_parameter("input", s);
1863
        std::vector<int64_t> lens  = {static_cast<int64_t>(batch_size), 64, 4, 96};
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
        auto rsp  = mm->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
        auto slc0 = mm->add_instruction(
            migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {32}}}), rsp);
        auto t0   = mm->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc0);
        auto slc1 = mm->add_instruction(
            migraphx::make_op("slice", {{"axes", {3}}, {"starts", {32}}, {"ends", {64}}}), rsp);
        auto t1   = mm->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc1);
        auto slc2 = mm->add_instruction(
            migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {96}}}), rsp);
        auto t2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), slc2);

        auto sum = mm->add_instruction(migraphx::make_op("add"), t0, t1);
        auto ret = mm->add_instruction(migraphx::make_op("dot"), sum, t2);
1877
        mm->add_return({ret});
1878
1879
1880
1881

        return p;
    };

1882
1883
    auto test = [&](std::size_t batch_size) {
        auto p1 = create_p1(batch_size);
1884
        auto p2 = create_p2(batch_size);
1885
1886
1887
1888
1889
1890
1891
1892
        run_pass(p1);
        EXPECT(p1.sort() == p2.sort());
    };

    test(4);
    test(8);
}

1893
1894
1895
1896
TEST_CASE(reorder_reshape_slice_move_axis2)
{
    auto create_p1 = [] {
        migraphx::program p1;
1897
        auto* mm1 = p1.get_main_module();
1898
        migraphx::shape s{migraphx::shape::float_type, {128, 96}};
1899
        auto input = mm1->add_parameter("input", s);
1900
1901
1902
1903
1904
1905
        auto slc0  = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {32}}}), input);
        auto slc1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {32}}, {"ends", {64}}}), input);
        auto slc2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {64}}, {"ends", {96}}}), input);
1906

1907
1908
1909
        auto c0 = mm1->add_instruction(migraphx::make_op("contiguous"), slc0);
        auto c1 = mm1->add_instruction(migraphx::make_op("contiguous"), slc1);
        auto c2 = mm1->add_instruction(migraphx::make_op("contiguous"), slc2);
1910
1911

        std::vector<int64_t> lens = {1, 16, 8, 32};
1912
1913
1914
        auto r0 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
        auto r1 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
        auto r2 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);
1915

1916
1917
        auto sum = mm1->add_instruction(migraphx::make_op("add"), r0, r1);
        auto ret = mm1->add_instruction(migraphx::make_op("mul"), sum, r2);
1918
        mm1->add_return({ret});
1919
1920
1921
1922
1923
1924

        return p1;
    };

    auto create_p2 = [] {
        migraphx::program p;
1925
        auto* mm                  = p.get_main_module();
1926
        auto s                    = migraphx::shape{migraphx::shape::float_type, {128, 96}};
1927
        auto input                = mm->add_parameter("input", s);
1928
        std::vector<int64_t> lens = {1, 16, 8, 96};
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
        auto rsp  = mm->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
        auto slc0 = mm->add_instruction(
            migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {32}}}), rsp);
        auto slc1 = mm->add_instruction(
            migraphx::make_op("slice", {{"axes", {3}}, {"starts", {32}}, {"ends", {64}}}), rsp);
        auto slc2 = mm->add_instruction(
            migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {96}}}), rsp);

        auto sum = mm->add_instruction(migraphx::make_op("add"), slc0, slc1);
        auto ret = mm->add_instruction(migraphx::make_op("mul"), sum, slc2);
1939
        mm->add_return({ret});
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953

        return p;
    };

    auto p1 = create_p1();
    auto p2 = create_p2();
    run_pass(p1);
    EXPECT(p1.sort() == p2.sort());
}

// Same slice/contiguous/reshape pattern as reorder_reshape_slice_move_axis2,
// but each 32-wide slice is reshaped to {1, 16, 16, 16}. The last output
// dimension (16) is not the 32-wide sliced extent. The test expects
// simplify_algebra to leave the program unchanged.
TEST_CASE(reorder_reshape_slice_not_apply)
{
    auto create_p = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::float_type, {128, 96}};
        auto input = mm->add_parameter("input", s);
        auto slc0  = mm->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {32}}}), input);
        auto slc1 = mm->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {32}}, {"ends", {64}}}), input);
        auto slc2 = mm->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {64}}, {"ends", {96}}}), input);

        auto c0 = mm->add_instruction(migraphx::make_op("contiguous"), slc0);
        auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), slc1);
        auto c2 = mm->add_instruction(migraphx::make_op("contiguous"), slc2);

        std::vector<int64_t> lens = {1, 16, 16, 16};
        auto r0 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
        auto r1 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
        auto r2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);

        auto sum = mm->add_instruction(migraphx::make_op("add"), r0, r1);
        auto ret = mm->add_instruction(migraphx::make_op("mul"), sum, r2);
        mm->add_return({ret});

        return p;
    };

    auto p1 = create_p();
    // Snapshot before the pass so the comparison detects any rewrite.
    auto p2 = p1;
    run_pass(p1);
    EXPECT(p1.sort() == p2.sort());
}

1986
1987
1988
1989
TEST_CASE(reorder_reshape_slice_diff_dims)
{
    auto create_p1 = [](std::size_t batch_size) {
        migraphx::program p1;
1990
1991
        auto* mm1 = p1.get_main_module();
        auto s    = migraphx::shape{migraphx::shape::float_type, {batch_size, 96, 96}};
1992
1993
        std::vector<int64_t> perm0 = {0, 2, 1, 3};
        std::vector<int64_t> perm1 = {0, 2, 3, 1};
1994
        auto input                 = mm1->add_parameter("input", s);
1995
1996
1997
1998
1999
2000
        auto slc0                  = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), input);
        auto slc1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), input);
        auto slc2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {64}}, {"ends", {96}}}), input);
2001

2002
2003
2004
        auto c0 = mm1->add_instruction(migraphx::make_op("contiguous"), slc0);
        auto c1 = mm1->add_instruction(migraphx::make_op("contiguous"), slc1);
        auto c2 = mm1->add_instruction(migraphx::make_op("contiguous"), slc2);
2005
2006
2007

        std::vector<int64_t> lens  = {static_cast<int64_t>(batch_size), 32, 3, 32};
        std::vector<int64_t> lens1 = {static_cast<int64_t>(batch_size), 48, 2, 32};
2008
2009
2010
        auto r0 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
        auto r1 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
        auto r2 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens1}}), c2);
2011

2012
        mm1->add_return({r0, r1, r2});
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032

        return p1;
    };

    auto test = [&](std::size_t batch_size) {
        auto p1 = create_p1(batch_size);
        auto p2 = p1;
        run_pass(p1);
        EXPECT(p1.sort() == p2.sort());
    };

    test(4);
    test(8);
}

// A {batch, 128, 1920} input is split into three 640-wide slices along axis 2,
// and each slice is transposed with the same perm {0, 2, 1}. After
// simplify_algebra the program must match create_p2: a single transpose of the
// input, followed by slices along axis 1 (where the sliced extent lands after
// the transpose).
TEST_CASE(reorder_slice_trans)
{
    std::vector<int64_t> perm = {0, 2, 1};
    auto create_p1            = [&](std::size_t batch_size) {
        migraphx::program p1;
        auto* mm1  = p1.get_main_module();
        auto s     = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
        auto input = mm1->add_parameter("input", s);
        auto slc0  = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {640}}}), input);
        auto slc1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {640}}, {"ends", {1280}}}),
            input);
        auto slc2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}),
            input);

        auto t0 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), slc0);
        auto t1 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), slc1);
        auto t2 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), slc2);

        auto sum = mm1->add_instruction(migraphx::make_op("add"), t0, t1);
        auto ret = mm1->add_instruction(migraphx::make_op("mul"), sum, t2);
        mm1->add_return({ret});

        return p1;
    };

    auto create_p2 = [&](std::size_t batch_size) {
        migraphx::program p2;
        auto* mm2  = p2.get_main_module();
        auto s     = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
        auto input = mm2->add_parameter("input", s);
        auto r     = mm2->add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), input);

        auto slc0 = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {640}}}), r);
        auto slc1 = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {640}}, {"ends", {1280}}}), r);
        auto slc2 = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1280}}, {"ends", {1920}}}), r);

        auto sum = mm2->add_instruction(migraphx::make_op("add"), slc0, slc1);
        auto ret = mm2->add_instruction(migraphx::make_op("mul"), sum, slc2);
        mm2->add_return({ret});

        return p2;
    };

    auto test = [&](std::size_t batch_size) {
        auto p1 = create_p1(batch_size);
        run_pass(p1);
        auto p2 = create_p2(batch_size);
        EXPECT(p1.sort() == p2.sort());
    };

    test(1);
    test(8);
}

// Three 640-wide slices along axis 2 of a {batch, 128, 1920} input are
// transposed with different perms: {0, 2, 1} for the first two and {0, 1, 2}
// for the third. Because the perms differ, the test expects simplify_algebra
// to leave the program unchanged.
TEST_CASE(reorder_slice_trans_diff_perm)
{
    auto create_p1 = [](std::size_t batch_size) {
        migraphx::program p1;
        auto* mm1 = p1.get_main_module();
        auto s    = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
        std::vector<int64_t> perm0 = {0, 2, 1};
        std::vector<int64_t> perm1 = {0, 1, 2};
        auto input                 = mm1->add_parameter("input", s);
        auto slc0                  = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {640}}}), input);
        auto slc1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {640}}, {"ends", {1280}}}),
            input);
        auto slc2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}),
            input);

        auto t0 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc0);
        auto t1 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc1);
        auto t2 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), slc2);

        auto sum = mm1->add_instruction(migraphx::make_op("add"), t0, t1);
        auto ret = mm1->add_instruction(migraphx::make_op("dot"), sum, t2);
        mm1->add_return({ret});

        return p1;
    };

    auto test = [&](std::size_t batch_size) {
        auto p1 = create_p1(batch_size);
        // Snapshot BEFORE running the pass. Copying after run_pass (as this
        // test previously did) compares the program with itself and can never
        // fail, so an unwanted rewrite would go unnoticed.
        auto p2 = p1;
        run_pass(p1);
        EXPECT(p1.sort() == p2.sort());
    };

    test(1);
    test(4);
}

Paul's avatar
Paul committed
2128
int main(int argc, const char* argv[]) { test::run(argc, argv); }