simplify_algebra_test.cpp 96.8 KB
Newer Older
Paul's avatar
Paul committed
1
2
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp>
3
#include <migraphx/pass_manager.hpp>
Paul's avatar
Paul committed
4
#include <migraphx/operators.hpp>
Paul's avatar
Paul committed
5
6
7
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
Paul's avatar
Paul committed
8
#include <basic_ops.hpp>
9
10
#include <migraphx/make_op.hpp>

Paul's avatar
Paul committed
11
12
#include <test.hpp>

13
// Runs simplify_algebra and then dead_code_elimination over the main module of `p`,
// rewriting the program in place.
void run_pass(migraphx::program& p)
{
    auto& main_module = *p.get_main_module();
    migraphx::run_passes(main_module,
                         {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}});
}
Paul's avatar
Paul committed
18

Paul's avatar
Paul committed
19
// (x + 1) + (y + 2) is expected to be regrouped so that parameters and literals are
// summed separately: (x + y) + (1 + 2).
TEST_CASE(simplify_add1)
{
    const migraphx::shape scalar{migraphx::shape::int32_type, {1}};
    const auto add_op = migraphx::make_op("add");

    migraphx::program actual;
    {
        auto* m     = actual.get_main_module();
        auto px     = m->add_parameter("x", scalar);
        auto py     = m->add_parameter("y", scalar);
        auto c1     = m->add_literal(1);
        auto c2     = m->add_literal(2);
        auto x_plus = m->add_instruction(add_op, px, c1);
        auto y_plus = m->add_instruction(add_op, py, c2);
        auto total  = m->add_instruction(add_op, x_plus, y_plus);
        m->add_instruction(pass_op{}, total);
    }
    run_pass(actual);

    migraphx::program expected;
    {
        auto* m         = expected.get_main_module();
        auto px         = m->add_parameter("x", scalar);
        auto py         = m->add_parameter("y", scalar);
        auto c1         = m->add_literal(1);
        auto c2         = m->add_literal(2);
        auto const_part = m->add_instruction(add_op, c1, c2);
        auto var_part   = m->add_instruction(add_op, px, py);
        auto total      = m->add_instruction(add_op, var_part, const_part);
        m->add_instruction(pass_op{}, total);
    }
    EXPECT(actual == expected);
}

Paul's avatar
Paul committed
50
// Same as simplify_add1 but with the literal on the left of each inner add:
// (1 + x) + (2 + y) is expected to become (x + y) + (1 + 2).
TEST_CASE(simplify_add2)
{
    const migraphx::shape scalar{migraphx::shape::int32_type, {1}};
    const auto add_op = migraphx::make_op("add");

    migraphx::program actual;
    {
        auto* m     = actual.get_main_module();
        auto px     = m->add_parameter("x", scalar);
        auto py     = m->add_parameter("y", scalar);
        auto c1     = m->add_literal(1);
        auto c2     = m->add_literal(2);
        auto x_plus = m->add_instruction(add_op, c1, px);
        auto y_plus = m->add_instruction(add_op, c2, py);
        auto total  = m->add_instruction(add_op, x_plus, y_plus);
        m->add_instruction(pass_op{}, total);
    }
    run_pass(actual);

    migraphx::program expected;
    {
        auto* m         = expected.get_main_module();
        auto px         = m->add_parameter("x", scalar);
        auto py         = m->add_parameter("y", scalar);
        auto c1         = m->add_literal(1);
        auto c2         = m->add_literal(2);
        auto const_part = m->add_instruction(add_op, c1, c2);
        auto var_part   = m->add_instruction(add_op, px, py);
        auto total      = m->add_instruction(add_op, var_part, const_part);
        m->add_instruction(pass_op{}, total);
    }
    EXPECT(actual == expected);
}

Paul's avatar
Paul committed
81
// (1 + x) + (1 + 2) is expected to be rewritten as x + (1 + (1 + 2)).
TEST_CASE(simplify_add3)
{
    const migraphx::shape scalar{migraphx::shape::int32_type, {1}};
    const auto add_op = migraphx::make_op("add");

    migraphx::program actual;
    {
        auto* m     = actual.get_main_module();
        auto px     = m->add_parameter("x", scalar);
        auto c1     = m->add_literal(1);
        auto c2     = m->add_literal(2);
        auto x_plus = m->add_instruction(add_op, c1, px);
        auto consts = m->add_instruction(add_op, c1, c2);
        auto total  = m->add_instruction(add_op, x_plus, consts);
        m->add_instruction(pass_op{}, total);
    }
    run_pass(actual);

    migraphx::program expected;
    {
        auto* m     = expected.get_main_module();
        auto px     = m->add_parameter("x", scalar);
        auto c1     = m->add_literal(1);
        auto c2     = m->add_literal(2);
        auto folded = m->add_instruction(add_op, c1, c2);
        auto nested = m->add_instruction(add_op, c1, folded);
        auto total  = m->add_instruction(add_op, px, nested);
        m->add_instruction(pass_op{}, total);
    }
    EXPECT(actual == expected);
}

110
111
112
113
114
115
116
// With broadcast literals, (x + bcast(1)) + (y + bcast(2)) is expected to become
// (x + y) + bcast(1 + 2): the literals are summed before being broadcast.
TEST_CASE(simplify_add_broadcast1)
{
    const migraphx::shape chan_shape{migraphx::shape::int32_type, {2}};
    const migraphx::shape full_shape{migraphx::shape::int32_type, {1, 2, 3, 3}};
    migraphx::op::broadcast bcast{1, {1, 2, 3, 3}};
    const auto add_op = migraphx::make_op("add");

    migraphx::program actual;
    {
        auto* m     = actual.get_main_module();
        auto px     = m->add_parameter("x", full_shape);
        auto py     = m->add_parameter("y", full_shape);
        auto ones   = m->add_literal({chan_shape, {1, 1}});
        auto ones_b = m->add_instruction(bcast, ones);
        auto twos   = m->add_literal({chan_shape, {2, 2}});
        auto twos_b = m->add_instruction(bcast, twos);
        auto lhs    = m->add_instruction(add_op, px, ones_b);
        auto rhs    = m->add_instruction(add_op, py, twos_b);
        auto total  = m->add_instruction(add_op, lhs, rhs);
        m->add_instruction(pass_op{}, total);
    }
    run_pass(actual);

    migraphx::program expected;
    {
        auto* m       = expected.get_main_module();
        auto px       = m->add_parameter("x", full_shape);
        auto py       = m->add_parameter("y", full_shape);
        auto ones     = m->add_literal({chan_shape, {1, 1}});
        auto twos     = m->add_literal({chan_shape, {2, 2}});
        auto folded   = m->add_instruction(add_op, ones, twos);
        auto folded_b = m->add_instruction(bcast, folded);
        auto vars     = m->add_instruction(add_op, px, py);
        auto total    = m->add_instruction(add_op, vars, folded_b);
        m->add_instruction(pass_op{}, total);
    }
    EXPECT(actual == expected);
}

// Here one constant addend is a broadcast literal and the other is a full-size
// literal; the pass is expected to leave the program unchanged.
TEST_CASE(simplify_add_broadcast2)
{
    migraphx::shape inner{migraphx::shape::int32_type, {2}};
    migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
    migraphx::op::broadcast bcast{1, {1, 2, 3, 3}};

    const auto build = [&]() -> migraphx::program {
        migraphx::program prog;
        auto* m     = prog.get_main_module();
        auto px     = m->add_parameter("x", outer);
        auto py     = m->add_parameter("y", outer);
        auto ones   = m->add_literal({inner, {1, 1}});
        auto ones_b = m->add_instruction(bcast, ones);
        auto twos =
            m->add_literal({outer, {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}});
        auto vars   = m->add_instruction(migraphx::make_op("add"), px, py);
        auto consts = m->add_instruction(migraphx::make_op("add"), ones_b, twos);
        auto total  = m->add_instruction(migraphx::make_op("add"), consts, vars);
        m->add_instruction(pass_op{}, total);
        return prog;
    };

    migraphx::program simplified = build();
    run_pass(simplified);

    migraphx::program untouched = build();
    EXPECT(simplified == untouched);
}

Paul's avatar
Paul committed
173
// TODO: Add test case
174
// TEST_CASE(simplify_add4)
Paul's avatar
Paul committed
175
176
void simplify_add4()
{
Paul's avatar
Paul committed
177
    migraphx::program p1;
Paul's avatar
Paul committed
178
    {
179
180
181
182
183
        auto* mm1 = p1.get_main_module();
        auto x    = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto y    = mm1->add_parameter("y", {migraphx::shape::int32_type, {1}});
        auto one  = mm1->add_literal(1);
        auto two  = mm1->add_literal(2);
184
185
186
        auto sum1 = mm1->add_instruction(migraphx::make_op("add"), one, x);
        auto sum2 = mm1->add_instruction(migraphx::make_op("add"), sum1, y);
        auto sum3 = mm1->add_instruction(migraphx::make_op("add"), sum2, two);
187
        mm1->add_instruction(pass_op{}, sum3);
Paul's avatar
Paul committed
188
    }
189
    run_pass(p1);
Paul's avatar
Paul committed
190

Paul's avatar
Paul committed
191
    migraphx::program p2;
Paul's avatar
Paul committed
192
    {
193
194
195
196
197
        auto* mm2 = p2.get_main_module();
        auto x    = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto y    = mm2->add_parameter("y", {migraphx::shape::int32_type, {1}});
        auto one  = mm2->add_literal(1);
        auto two  = mm2->add_literal(2);
198
199
200
        auto sum1 = mm2->add_instruction(migraphx::make_op("add"), one, two);
        auto sum2 = mm2->add_instruction(migraphx::make_op("add"), x, y);
        auto sum3 = mm2->add_instruction(migraphx::make_op("add"), sum2, sum1);
201
        mm2->add_instruction(pass_op{}, sum3);
Paul's avatar
Paul committed
202
203
204
205
    }
    EXPECT(p1 == p2);
}

Paul's avatar
Paul committed
206
207
208
// Multiplying a convolution's output by a per-channel broadcast constant is expected
// to be folded away, so that after the pass the convolution no longer feeds a "mul".
TEST_CASE(simplify_mul_conv1)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto x   = mm->add_parameter("x", {migraphx::shape::int32_type, {1, 128, 28, 28}});
    auto w   = mm->add_literal(
        migraphx::generate_literal({migraphx::shape::int32_type, {256, 128, 3, 3}}));
    auto conv = mm->add_instruction(
        migraphx::make_op("convolution",
                          {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
        x,
        w);
    auto a = mm->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256}}));
    auto b = mm->add_instruction(
        migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 256, 14, 14}}}), a);
    auto mul = mm->add_instruction(migraphx::make_op("mul"), conv, b);
    mm->add_instruction(pass_op{}, mul);
    EXPECT(conv->outputs().front()->name() == "mul");
    run_pass(p);

    auto new_conv = std::find_if(
        mm->begin(), mm->end(), [](auto&& ins) { return ins.name() == "convolution"; });
    // find_if yields end() if no convolution survives the pass, and an instruction with
    // no users has an empty outputs() list. Check both before dereferencing, so either
    // case is reported as a test failure instead of undefined behavior. This relies on
    // EXPECT stopping the test when it fails.
    EXPECT(new_conv != mm->end());
    EXPECT(not new_conv->outputs().empty());
    EXPECT(new_conv->outputs().front()->name() != "mul");
}

230
231
232
233
// A convolution is split into channel slices [0, 384) and [384, 768), and only the first
// slice is scaled by a per-channel broadcast constant. The scaling is expected to move
// onto the matching slice of the weights. The scaled and unscaled weight slices are then
// concatenated, and a single convolution feeds both output slices.
TEST_CASE(simplify_mul_slice_conv1)
{
    const migraphx::shape input_shape{migraphx::shape::int32_type, {1, 1024, 17, 17}};
    const migraphx::shape weight_shape{migraphx::shape::int32_type, {768, 1024, 1, 1}};
    const migraphx::shape scale_shape{migraphx::shape::int32_type, {384}};
    const auto slice_op = [](int axis, int start, int end) {
        return migraphx::make_op("slice",
                                 {{"axes", {axis}}, {"starts", {start}}, {"ends", {end}}});
    };

    migraphx::program actual;
    {
        auto* m      = actual.get_main_module();
        auto input   = m->add_parameter("x", input_shape);
        auto weights = m->add_literal(migraphx::generate_literal(weight_shape));
        auto conv    = m->add_instruction(migraphx::make_op("convolution"), input, weights);
        auto first   = m->add_instruction(slice_op(1, 0, 384), conv);
        auto scale   = m->add_literal(migraphx::generate_literal(scale_shape));
        auto scale_b = m->add_instruction(
            migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), scale);
        auto scaled = m->add_instruction(migraphx::make_op("mul"), first, scale_b);
        auto second = m->add_instruction(slice_op(1, 384, 768), conv);
        auto total  = m->add_instruction(migraphx::make_op("add"), scaled, second);
        m->add_instruction(pass_op{}, total);
    }
    run_pass(actual);

    migraphx::program expected;
    {
        auto* m      = expected.get_main_module();
        auto input   = m->add_parameter("x", input_shape);
        auto weights = m->add_literal(migraphx::generate_literal(weight_shape));
        auto w_first = m->add_instruction(slice_op(0, 0, 384), weights);
        auto scale   = m->add_literal(migraphx::generate_literal(scale_shape));
        auto scale_b = m->add_instruction(
            migraphx::make_op("broadcast", {{"axis", 0}, {"dims", {384, 1024, 1, 1}}}), scale);
        auto w_scaled = m->add_instruction(migraphx::make_op("mul"), scale_b, w_first);
        auto w_second = m->add_instruction(slice_op(0, 384, 768), weights);
        auto w_all    = m->add_instruction(
            migraphx::make_op("concat", {{"axis", 0}}), w_scaled, w_second);
        auto conv   = m->add_instruction(migraphx::make_op("convolution"), input, w_all);
        auto first  = m->add_instruction(slice_op(1, 0, 384), conv);
        auto second = m->add_instruction(slice_op(1, 384, 768), conv);
        auto total  = m->add_instruction(migraphx::make_op("add"), first, second);
        m->add_instruction(pass_op{}, total);
    }
    EXPECT(actual == expected);
}

// Like simplify_mul_slice_conv1, but the two channel slices overlap ([0, 384) and
// [383, 767)). The pass is expected to leave the program unchanged.
TEST_CASE(simplify_mul_slice_conv_overlapping_slice)
{
    const migraphx::shape input_shape{migraphx::shape::int32_type, {1, 1024, 17, 17}};
    const migraphx::shape weight_shape{migraphx::shape::int32_type, {768, 1024, 1, 1}};
    const migraphx::shape scale_shape{migraphx::shape::int32_type, {384}};

    migraphx::program prog;
    {
        auto* m      = prog.get_main_module();
        auto input   = m->add_parameter("x", input_shape);
        auto weights = m->add_literal(migraphx::generate_literal(weight_shape));
        auto conv    = m->add_instruction(migraphx::make_op("convolution"), input, weights);
        auto first   = m->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv);
        auto scale   = m->add_literal(migraphx::generate_literal(scale_shape));
        auto scale_b = m->add_instruction(
            migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), scale);
        auto scaled = m->add_instruction(migraphx::make_op("mul"), first, scale_b);
        auto second = m->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {383}}, {"ends", {767}}}), conv);
        auto total = m->add_instruction(migraphx::make_op("add"), scaled, second);
        m->add_instruction(pass_op{}, total);
    }

    migraphx::program reference = prog;
    run_pass(prog);
    EXPECT(prog == reference);
}

// The convolution feeds one scaled channel slice and also an "add" that uses the whole
// (unsliced) output. The pass is expected to leave the program unchanged.
TEST_CASE(simplify_mul_slice_conv_not_all_slice)
{
    const migraphx::shape input_shape{migraphx::shape::int32_type, {1, 1024, 17, 17}};
    const migraphx::shape weight_shape{migraphx::shape::int32_type, {768, 1024, 1, 1}};
    const migraphx::shape scale_shape{migraphx::shape::int32_type, {384}};
    const migraphx::shape bias_shape{migraphx::shape::int32_type, {1, 768, 17, 17}};

    migraphx::program prog;
    {
        auto* m      = prog.get_main_module();
        auto input   = m->add_parameter("x", input_shape);
        auto weights = m->add_literal(migraphx::generate_literal(weight_shape));
        auto conv    = m->add_instruction(migraphx::make_op("convolution"), input, weights);
        auto first   = m->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv);
        auto scale   = m->add_literal(migraphx::generate_literal(scale_shape));
        auto scale_b = m->add_instruction(
            migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), scale);
        auto scaled = m->add_instruction(migraphx::make_op("mul"), first, scale_b);
        auto bias   = m->add_literal(migraphx::generate_literal(bias_shape));
        auto biased = m->add_instruction(migraphx::make_op("add"), conv, bias);
        auto joined =
            m->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), scaled, biased);
        m->add_instruction(pass_op{}, joined);
    }

    migraphx::program reference = prog;
    run_pass(prog);
    EXPECT(prog == reference);
}

Paul's avatar
Paul committed
330
331
332
333
// (1 + x) * 2 is expected to be distributed into (2 * x) + (2 * 1).
TEST_CASE(simplify_mul_add)
{
    const migraphx::shape scalar{migraphx::shape::int32_type, {1}};

    migraphx::program actual;
    {
        auto* m     = actual.get_main_module();
        auto px     = m->add_parameter("x", scalar);
        auto c1     = m->add_literal(1);
        auto c2     = m->add_literal(2);
        auto x_plus = m->add_instruction(migraphx::make_op("add"), c1, px);
        auto prod   = m->add_instruction(migraphx::make_op("mul"), x_plus, c2);
        m->add_instruction(pass_op{}, prod);
    }
    run_pass(actual);

    migraphx::program expected;
    {
        auto* m        = expected.get_main_module();
        auto px        = m->add_parameter("x", scalar);
        auto c1        = m->add_literal(1);
        auto c2        = m->add_literal(2);
        auto scaled_x  = m->add_instruction(migraphx::make_op("mul"), c2, px);
        auto scaled_c1 = m->add_instruction(migraphx::make_op("mul"), c2, c1);
        auto total     = m->add_instruction(migraphx::make_op("add"), scaled_x, scaled_c1);
        m->add_instruction(pass_op{}, total);
    }
    EXPECT(actual == expected);
}

Paul's avatar
Paul committed
358
359
360
361
362
// bcast(x) + bcast(y), with the same broadcast op on both sides, is expected to become
// bcast(x + y): the add is moved inside the broadcast.
TEST_CASE(simplify_inner_broadcast)
{
    const migraphx::shape scalar{migraphx::shape::int32_type, {1}};
    auto bcast = migraphx::op::broadcast{1, {2, 1, 4, 5}};

    migraphx::program actual;
    {
        auto* m    = actual.get_main_module();
        auto px    = m->add_parameter("x", scalar);
        auto py    = m->add_parameter("y", scalar);
        auto px_b  = m->add_instruction(bcast, px);
        auto py_b  = m->add_instruction(bcast, py);
        auto total = m->add_instruction(migraphx::make_op("add"), px_b, py_b);
        m->add_instruction(pass_op{}, total);
    }
    run_pass(actual);

    migraphx::program expected;
    {
        auto* m      = expected.get_main_module();
        auto px      = m->add_parameter("x", scalar);
        auto py      = m->add_parameter("y", scalar);
        auto total   = m->add_instruction(migraphx::make_op("add"), px, py);
        auto total_b = m->add_instruction(bcast, total);
        m->add_instruction(pass_op{}, total_b);
    }
    EXPECT(actual == expected);
}

385
386
387
// Two same-shaped 3x3 convolutions whose results are added are expected to be merged
// into one convolution, and the output shape must be preserved.
TEST_CASE(simplify_add_conv1)
{
    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 128, 28, 28}};
    const migraphx::shape weight_shape{migraphx::shape::float_type, {256, 128, 3, 3}};

    migraphx::program prog;
    auto* m     = prog.get_main_module();
    auto in_x   = m->add_parameter("x", input_shape);
    auto w_x    = m->add_literal(migraphx::generate_literal(weight_shape));
    auto in_y   = m->add_parameter("y", input_shape);
    auto w_y    = m->add_literal(migraphx::generate_literal(weight_shape));
    auto conv_x = m->add_instruction(migraphx::make_op("convolution"), in_x, w_x);
    auto conv_y = m->add_instruction(migraphx::make_op("convolution"), in_y, w_y);
    auto total  = m->add_instruction(migraphx::make_op("add"), conv_x, conv_y);
    m->add_instruction(pass_op{}, total);

    const auto out_shape = prog.get_output_shapes().back();
    run_pass(prog);
    EXPECT(out_shape == prog.get_output_shapes().back());

    const auto conv_count = std::count_if(
        m->begin(), m->end(), [](auto&& ins) { return ins.name() == "convolution"; });
    EXPECT(conv_count == 1);
}

// A stride-1 7x7 convolution on a 14x14 input and a stride-3 7x7 convolution on a 28x28
// input are added together. Both convolutions are expected to survive (no fusion), and
// the output shape must be preserved.
TEST_CASE(simplify_add_conv_no_fusion_7x7_diff_strides)
{
    const migraphx::shape x_shape{migraphx::shape::float_type, {1, 128, 14, 14}};
    const migraphx::shape y_shape{migraphx::shape::float_type, {1, 128, 28, 28}};
    const migraphx::shape weight_shape{migraphx::shape::float_type, {256, 128, 7, 7}};

    migraphx::program prog;
    auto* m     = prog.get_main_module();
    auto in_x   = m->add_parameter("x", x_shape);
    auto w_x    = m->add_literal(migraphx::generate_literal(weight_shape));
    auto in_y   = m->add_parameter("y", y_shape);
    auto w_y    = m->add_literal(migraphx::generate_literal(weight_shape));
    auto conv_x = m->add_instruction(migraphx::make_op("convolution"), in_x, w_x);
    auto conv_y = m->add_instruction(
        migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {3, 3}}}), in_y, w_y);
    auto total = m->add_instruction(migraphx::make_op("add"), conv_x, conv_y);
    m->add_instruction(pass_op{}, total);

    const auto out_shape = prog.get_output_shapes().back();
    run_pass(prog);
    EXPECT(out_shape == prog.get_output_shapes().back());

    const auto conv_count = std::count_if(
        m->begin(), m->end(), [](auto&& ins) { return ins.name() == "convolution"; });
    EXPECT(conv_count == 2);
}

// A stride-1 1x1 convolution on 14x14 plus a stride-2 1x1 convolution on 28x28 is
// expected to fuse into a single convolution, and the output shape must be preserved.
TEST_CASE(simplify_add_conv_1x1_diff_strides1)
{
    const migraphx::shape x_shape{migraphx::shape::float_type, {1, 128, 14, 14}};
    const migraphx::shape y_shape{migraphx::shape::float_type, {1, 128, 28, 28}};
    const migraphx::shape weight_shape{migraphx::shape::float_type, {256, 128, 1, 1}};

    migraphx::program prog;
    auto* m     = prog.get_main_module();
    auto in_x   = m->add_parameter("x", x_shape);
    auto w_x    = m->add_literal(migraphx::generate_literal(weight_shape));
    auto in_y   = m->add_parameter("y", y_shape);
    auto w_y    = m->add_literal(migraphx::generate_literal(weight_shape));
    auto conv_x = m->add_instruction(migraphx::make_op("convolution"), in_x, w_x);
    auto conv_y = m->add_instruction(
        migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 2}}}), in_y, w_y);
    auto total = m->add_instruction(migraphx::make_op("add"), conv_x, conv_y);
    m->add_instruction(pass_op{}, total);

    const auto out_shape = prog.get_output_shapes().back();
    run_pass(prog);
    EXPECT(out_shape == prog.get_output_shapes().back());

    const auto conv_count = std::count_if(
        m->begin(), m->end(), [](auto&& ins) { return ins.name() == "convolution"; });
    EXPECT(conv_count == 1);
}

// Mirror of simplify_add_conv_1x1_diff_strides1: here the strided 1x1 convolution is the
// left operand. The two convolutions are expected to fuse into one, and the output shape
// must be preserved.
TEST_CASE(simplify_add_conv_1x1_diff_strides2)
{
    const migraphx::shape x_shape{migraphx::shape::float_type, {1, 128, 28, 28}};
    const migraphx::shape y_shape{migraphx::shape::float_type, {1, 128, 14, 14}};
    const migraphx::shape weight_shape{migraphx::shape::float_type, {256, 128, 1, 1}};

    migraphx::program prog;
    auto* m     = prog.get_main_module();
    auto in_x   = m->add_parameter("x", x_shape);
    auto w_x    = m->add_literal(migraphx::generate_literal(weight_shape));
    auto in_y   = m->add_parameter("y", y_shape);
    auto w_y    = m->add_literal(migraphx::generate_literal(weight_shape));
    auto conv_x = m->add_instruction(
        migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 2}}}), in_x, w_x);
    auto conv_y = m->add_instruction(migraphx::make_op("convolution"), in_y, w_y);
    auto total  = m->add_instruction(migraphx::make_op("add"), conv_x, conv_y);
    m->add_instruction(pass_op{}, total);

    const auto out_shape = prog.get_output_shapes().back();
    run_pass(prog);
    EXPECT(out_shape == prog.get_output_shapes().back());

    const auto conv_count = std::count_if(
        m->begin(), m->end(), [](auto&& ins) { return ins.name() == "convolution"; });
    EXPECT(conv_count == 1);
}

// Odd spatial sizes: a stride-1 1x1 convolution on 83x83 plus a stride-2 1x1 convolution
// on 165x165 is expected to fuse into one convolution, and the output shape must be
// preserved.
TEST_CASE(simplify_add_conv_1x1_diff_strides_odd)
{
    const migraphx::shape x_shape{migraphx::shape::float_type, {1, 54, 83, 83}};
    const migraphx::shape y_shape{migraphx::shape::float_type, {1, 54, 165, 165}};
    const migraphx::shape weight_shape{migraphx::shape::float_type, {54, 54, 1, 1}};

    migraphx::program prog;
    auto* m     = prog.get_main_module();
    auto in_x   = m->add_parameter("x", x_shape);
    auto w_x    = m->add_literal(migraphx::generate_literal(weight_shape));
    auto in_y   = m->add_parameter("y", y_shape);
    auto w_y    = m->add_literal(migraphx::generate_literal(weight_shape));
    auto conv_x = m->add_instruction(migraphx::make_op("convolution"), in_x, w_x);
    auto conv_y = m->add_instruction(
        migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 2}}}), in_y, w_y);
    auto total = m->add_instruction(migraphx::make_op("add"), conv_x, conv_y);
    m->add_instruction(pass_op{}, total);

    const auto out_shape = prog.get_output_shapes().back();
    run_pass(prog);
    EXPECT(out_shape == prog.get_output_shapes().back());

    const auto conv_count = std::count_if(
        m->begin(), m->end(), [](auto&& ins) { return ins.name() == "convolution"; });
    EXPECT(conv_count == 1);
}

// The left convolution uses an asymmetric {2, 1} stride. Both convolutions are expected
// to survive (no fusion), and the output shape must be preserved.
TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides1)
{
    const migraphx::shape x_shape{migraphx::shape::float_type, {1, 128, 28, 14}};
    const migraphx::shape y_shape{migraphx::shape::float_type, {1, 128, 14, 14}};
    const migraphx::shape weight_shape{migraphx::shape::float_type, {256, 128, 1, 1}};

    migraphx::program prog;
    auto* m     = prog.get_main_module();
    auto in_x   = m->add_parameter("x", x_shape);
    auto w_x    = m->add_literal(migraphx::generate_literal(weight_shape));
    auto in_y   = m->add_parameter("y", y_shape);
    auto w_y    = m->add_literal(migraphx::generate_literal(weight_shape));
    auto conv_x = m->add_instruction(
        migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 1}}}), in_x, w_x);
    auto conv_y = m->add_instruction(migraphx::make_op("convolution"), in_y, w_y);
    auto total  = m->add_instruction(migraphx::make_op("add"), conv_x, conv_y);
    m->add_instruction(pass_op{}, total);

    const auto out_shape = prog.get_output_shapes().back();
    run_pass(prog);
    EXPECT(out_shape == prog.get_output_shapes().back());

    const auto conv_count = std::count_if(
        m->begin(), m->end(), [](auto&& ins) { return ins.name() == "convolution"; });
    EXPECT(conv_count == 2);
}

// Mirror of ..._asymetrical_strides1: the right convolution uses the {2, 1} stride.
// Both convolutions are expected to survive (no fusion), and the output shape must be
// preserved.
TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides2)
{
    const migraphx::shape x_shape{migraphx::shape::float_type, {1, 128, 14, 14}};
    const migraphx::shape y_shape{migraphx::shape::float_type, {1, 128, 28, 14}};
    const migraphx::shape weight_shape{migraphx::shape::float_type, {256, 128, 1, 1}};

    migraphx::program prog;
    auto* m     = prog.get_main_module();
    auto in_x   = m->add_parameter("x", x_shape);
    auto w_x    = m->add_literal(migraphx::generate_literal(weight_shape));
    auto in_y   = m->add_parameter("y", y_shape);
    auto w_y    = m->add_literal(migraphx::generate_literal(weight_shape));
    auto conv_x = m->add_instruction(migraphx::make_op("convolution"), in_x, w_x);
    auto conv_y = m->add_instruction(
        migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 1}}}), in_y, w_y);
    auto total = m->add_instruction(migraphx::make_op("add"), conv_x, conv_y);
    m->add_instruction(pass_op{}, total);

    const auto out_shape = prog.get_output_shapes().back();
    run_pass(prog);
    EXPECT(out_shape == prog.get_output_shapes().back());

    const auto conv_count = std::count_if(
        m->begin(), m->end(), [](auto&& ins) { return ins.name() == "convolution"; });
    EXPECT(conv_count == 2);
}

548
549
550
551
552
// concat(relu(x + 1), relu(y + 2)) is expected to become
// relu(concat(x, y) + concat(1, 2)).
TEST_CASE(simplify_concat_add_relu)
{
    const auto unit      = migraphx::shape{migraphx::shape::int32_type, {1}};
    const auto concat_op = migraphx::make_op("concat", {{"axis", 0}});

    migraphx::program actual;
    {
        auto* m     = actual.get_main_module();
        auto px     = m->add_parameter("x", unit);
        auto py     = m->add_parameter("y", unit);
        auto c1     = m->add_literal({unit, {1}});
        auto c2     = m->add_literal({unit, {2}});
        auto x_sum  = m->add_instruction(migraphx::make_op("add"), px, c1);
        auto x_act  = m->add_instruction(migraphx::make_op("relu"), x_sum);
        auto y_sum  = m->add_instruction(migraphx::make_op("add"), py, c2);
        auto y_act  = m->add_instruction(migraphx::make_op("relu"), y_sum);
        auto joined = m->add_instruction(concat_op, x_act, y_act);
        m->add_instruction(pass_op{}, joined);
    }
    run_pass(actual);

    migraphx::program expected;
    {
        auto* m        = expected.get_main_module();
        auto px        = m->add_parameter("x", unit);
        auto py        = m->add_parameter("y", unit);
        auto c1        = m->add_literal({unit, {1}});
        auto c2        = m->add_literal({unit, {2}});
        auto vars      = m->add_instruction(concat_op, px, py);
        auto consts    = m->add_instruction(concat_op, c1, c2);
        auto total     = m->add_instruction(migraphx::make_op("add"), vars, consts);
        auto activated = m->add_instruction(migraphx::make_op("relu"), total);
        m->add_instruction(pass_op{}, activated);
    }
    EXPECT(actual == expected);
}

584
585
586
587
588
// Only the relu(add(...)) inputs of the concat are expected to be merged. The leading
// x + y input is kept as is, and the result is concat(x + y, relu(concat(x, y) +
// concat(1, 2))). Programs are sorted before comparison.
TEST_CASE(simplify_concat_add_relu_partial)
{
    const auto unit      = migraphx::shape{migraphx::shape::int32_type, {1}};
    const auto concat_op = migraphx::make_op("concat", {{"axis", 0}});

    migraphx::program actual;
    {
        auto* m     = actual.get_main_module();
        auto px     = m->add_parameter("x", unit);
        auto py     = m->add_parameter("y", unit);
        auto c1     = m->add_literal({unit, {1}});
        auto c2     = m->add_literal({unit, {2}});
        auto x_sum  = m->add_instruction(migraphx::make_op("add"), px, c1);
        auto x_act  = m->add_instruction(migraphx::make_op("relu"), x_sum);
        auto y_sum  = m->add_instruction(migraphx::make_op("add"), py, c2);
        auto y_act  = m->add_instruction(migraphx::make_op("relu"), y_sum);
        auto direct = m->add_instruction(migraphx::make_op("add"), px, py);
        auto joined = m->add_instruction(concat_op, direct, x_act, y_act);
        m->add_instruction(pass_op{}, joined);
    }
    run_pass(actual);

    migraphx::program expected;
    {
        auto* m        = expected.get_main_module();
        auto px        = m->add_parameter("x", unit);
        auto py        = m->add_parameter("y", unit);
        auto c1        = m->add_literal({unit, {1}});
        auto c2        = m->add_literal({unit, {2}});
        auto vars      = m->add_instruction(concat_op, px, py);
        auto consts    = m->add_instruction(concat_op, c1, c2);
        auto total     = m->add_instruction(migraphx::make_op("add"), vars, consts);
        auto activated = m->add_instruction(migraphx::make_op("relu"), total);
        auto direct    = m->add_instruction(migraphx::make_op("add"), px, py);
        auto joined    = m->add_instruction(concat_op, direct, activated);
        m->add_instruction(pass_op{}, joined);
    }
    EXPECT(actual.sort() == expected.sort());
}

// Broadcast literal inputs of a concat are expected to be merged: the literals are
// concatenated first and broadcast once to the combined shape, while the add(x, y)
// input is kept as its own concat input.
TEST_CASE(simplify_concat_add_relu_partial_broadcast)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
    migraphx::program p1;
    {
        auto* mm1 = p1.get_main_module();
        auto b    = migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {2, 1, 4, 5}}});
        auto x    = mm1->add_parameter("x", s);
        auto y    = mm1->add_parameter("y", s);
        auto one  = mm1->add_literal(1);
        auto oneb = mm1->add_instruction(b, one);
        auto two  = mm1->add_literal(2);
        auto twob = mm1->add_instruction(b, two);
        auto sum  = mm1->add_instruction(migraphx::make_op("add"), x, y);
        auto concat =
            mm1->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), sum, oneb, twob);
        mm1->add_instruction(pass_op{}, concat);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2    = p2.get_main_module();
        auto b       = migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {2, 2, 4, 5}}});
        auto x       = mm2->add_parameter("x", s);
        auto y       = mm2->add_parameter("y", s);
        auto one     = mm2->add_literal(1);
        auto two     = mm2->add_literal(2);
        auto concat1 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
        auto concatb = mm2->add_instruction(b, concat1);
        auto sum     = mm2->add_instruction(migraphx::make_op("add"), x, y);
        auto concat2 =
            mm2->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), sum, concatb);
        mm2->add_instruction(pass_op{}, concat2);
    }
    EXPECT(p1.sort() == p2.sort());
}

// Concat along axis 1 (the broadcast axis) of relu(add(param, broadcast literal)) inputs:
// the add+relu is expected to be hoisted above the concat, with the raw literals
// concatenated first and broadcast once to the combined {2, 2, 4, 5} shape.
TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {2, 1, 4, 5}}});
        auto x     = mm1->add_parameter("x", s);
        auto y     = mm1->add_parameter("y", s);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto sum2  = mm1->add_instruction(migraphx::make_op("add"), y, twob);
        auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto concat =
            mm1->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), relu1, relu2);
        mm1->add_instruction(pass_op{}, concat);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2     = p2.get_main_module();
        auto b        = migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {2, 2, 4, 5}}});
        auto x        = mm2->add_parameter("x", s);
        auto y        = mm2->add_parameter("y", s);
        auto one      = mm2->add_literal(1);
        auto two      = mm2->add_literal(2);
        auto concat1  = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y);
        auto concat2  = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
        auto concat2b = mm2->add_instruction(b, concat2);
        auto sum      = mm2->add_instruction(migraphx::make_op("add"), concat1, concat2b);
        auto relu     = mm2->add_instruction(migraphx::make_op("relu"), sum);
        mm2->add_instruction(pass_op{}, relu);
    }
    EXPECT(p1 == p2);
}

// Concat along axis 0 (not the broadcast axis) of relu(add(param, broadcast literal)) inputs:
// the add+relu is expected to be hoisted above the concat, and the already-broadcast
// literals are concatenated directly.
TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {2, 1, 4, 5}}});
        auto x     = mm1->add_parameter("x", s);
        auto y     = mm1->add_parameter("y", s);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto sum2  = mm1->add_instruction(migraphx::make_op("add"), y, twob);
        auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto concat =
            mm1->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), relu1, relu2);
        mm1->add_instruction(pass_op{}, concat);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2    = p2.get_main_module();
        auto b       = migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {2, 1, 4, 5}}});
        auto x       = mm2->add_parameter("x", s);
        auto y       = mm2->add_parameter("y", s);
        auto one     = mm2->add_literal(1);
        auto oneb    = mm2->add_instruction(b, one);
        auto two     = mm2->add_literal(2);
        auto twob    = mm2->add_instruction(b, two);
        auto concat1 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y);
        auto concat2 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), oneb, twob);
        auto sum     = mm2->add_instruction(migraphx::make_op("add"), concat1, concat2);
        auto relu    = mm2->add_instruction(migraphx::make_op("relu"), sum);
        mm2->add_instruction(pass_op{}, relu);
    }
    EXPECT(p1 == p2);
}

// Division by a literal is expected to become multiplication by the literal's reciprocal,
// with the recip instruction placed directly after the literal.
TEST_CASE(simplify_div_const)
{
    migraphx::program p1;
    {
        auto* mm1 = p1.get_main_module();
        auto x    = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto two  = mm1->add_literal(2);
        mm1->add_instruction(migraphx::make_op("div"), x, two);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2  = p2.get_main_module();
        auto x     = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto two   = mm2->add_literal(2);
        auto recip = mm2->insert_instruction(std::next(two), migraphx::make_op("recip"), two);
        mm2->add_instruction(migraphx::make_op("mul"), x, recip);
    }
    EXPECT(p1 == p2);
}

// Subtraction of a literal is expected to become addition of the literal's negation,
// with the neg instruction placed directly after the literal.
TEST_CASE(simplify_sub_const)
{
    migraphx::program p1;
    {
        auto* mm1 = p1.get_main_module();
        auto x    = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto two  = mm1->add_literal(2);
        mm1->add_instruction(migraphx::make_op("sub"), x, two);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2 = p2.get_main_module();
        auto x    = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto two  = mm2->add_literal(2);
        auto neg  = mm2->insert_instruction(std::next(two), migraphx::make_op("neg"), two);
        mm2->add_instruction(migraphx::make_op("add"), x, neg);
    }
    EXPECT(p1 == p2);
}

// recip(sqrt(x)) is expected to be fused into a single rsqrt(x).
TEST_CASE(simplify_rsqrt)
{
    migraphx::program p1;
    {
        auto* mm1 = p1.get_main_module();
        auto x    = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto sqrt = mm1->add_instruction(migraphx::make_op("sqrt"), x);
        mm1->add_instruction(migraphx::make_op("recip"), sqrt);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2 = p2.get_main_module();
        auto x    = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
        mm2->add_instruction(migraphx::make_op("rsqrt"), x);
    }
    EXPECT(p1 == p2);
}

// When sqrt has users other than recip, the recip(sqrt) -> rsqrt fusion is expected
// not to fire: the program must be unchanged by the pass.
TEST_CASE(simplify_rsqrt_multi_use)
{
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto x     = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto sqrt  = mm1->add_instruction(migraphx::make_op("sqrt"), x);
        auto add   = mm1->add_instruction(migraphx::make_op("add"), sqrt, sqrt);
        auto recip = mm1->add_instruction(migraphx::make_op("recip"), sqrt);
        mm1->add_instruction(migraphx::make_op("add"), recip, add);
    }
    // Snapshot before the pass; the pass should be a no-op.
    migraphx::program p2{p1};

    run_pass(p1);
    EXPECT(p1 == p2);
}

// A concat of in-order slices that together cover each input ([0,128) then [128,256))
// is expected to collapse to a concat of the inputs themselves.
TEST_CASE(simplify_slice_concat)
{
    auto s = migraphx::shape{migraphx::shape::float_type, {256}};

    migraphx::program p1;
    {
        auto* mm1    = p1.get_main_module();
        auto x       = mm1->add_parameter("x", s);
        auto y       = mm1->add_parameter("y", s);
        auto xslice1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {128}}}), x);
        auto xslice2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {128}}, {"ends", {256}}}), x);
        auto yslice1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {128}}}), y);
        auto yslice2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {128}}, {"ends", {256}}}), y);
        auto concat = mm1->add_instruction(
            migraphx::make_op("concat", {{"axis", 0}}), xslice1, xslice2, yslice1, yslice2);
        mm1->add_instruction(pass_op{}, concat);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2   = p2.get_main_module();
        auto x      = mm2->add_parameter("x", s);
        auto y      = mm2->add_parameter("y", s);
        auto concat = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y);
        mm2->add_instruction(pass_op{}, concat);
    }
    EXPECT(p1 == p2);
}

// Same as simplify_slice_concat, but with slices of unequal widths (64, 128, 64):
// in-order slices covering each input are still expected to collapse to the inputs.
TEST_CASE(simplify_slice_concat_non_uniform)
{
    auto s = migraphx::shape{migraphx::shape::float_type, {256}};

    migraphx::program p1;
    {
        auto* mm1    = p1.get_main_module();
        auto x       = mm1->add_parameter("x", s);
        auto y       = mm1->add_parameter("y", s);
        auto xslice1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {64}}}), x);
        auto xslice2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {64}}, {"ends", {192}}}), x);
        auto xslice3 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {192}}, {"ends", {256}}}), x);
        auto yslice1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {64}}}), y);
        auto yslice2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {64}}, {"ends", {192}}}), y);
        auto yslice3 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {192}}, {"ends", {256}}}), y);
        auto concat = mm1->add_instruction(migraphx::make_op("concat", {{"axis", 0}}),
                                           xslice1,
                                           xslice2,
                                           xslice3,
                                           yslice1,
                                           yslice2,
                                           yslice3);
        mm1->add_instruction(pass_op{}, concat);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2   = p2.get_main_module();
        auto x      = mm2->add_parameter("x", s);
        auto y      = mm2->add_parameter("y", s);
        auto concat = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y);
        mm2->add_instruction(pass_op{}, concat);
    }

    EXPECT(p1 == p2);
}

// Slices concatenated out of order ([0,64), [192,256), [64,192)) do not reassemble
// the inputs, so the pass is expected to leave the program unchanged.
TEST_CASE(simplify_slice_concat_flipped)
{
    auto s = migraphx::shape{migraphx::shape::float_type, {256}};

    migraphx::program p1;
    {
        auto* mm1    = p1.get_main_module();
        auto x       = mm1->add_parameter("x", s);
        auto y       = mm1->add_parameter("y", s);
        auto xslice1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {64}}}), x);
        auto xslice2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {192}}, {"ends", {256}}}), x);
        auto xslice3 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {64}}, {"ends", {192}}}), x);
        auto yslice1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {64}}}), y);
        auto yslice2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {192}}, {"ends", {256}}}), y);
        auto yslice3 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {64}}, {"ends", {192}}}), y);
        auto concat = mm1->add_instruction(migraphx::make_op("concat", {{"axis", 0}}),
                                           xslice1,
                                           xslice2,
                                           xslice3,
                                           yslice1,
                                           yslice2,
                                           yslice3);
        mm1->add_instruction(pass_op{}, concat);
    }
    // Snapshot before the pass; the pass should be a no-op.
    migraphx::program p2 = p1;
    run_pass(p1);

    EXPECT(p1 == p2);
}

// Parallel add+relu chains on complementary slices of one input are expected to be hoisted
// above the slices: a single add+relu on the full input (with the literals concatenated and
// broadcast once), followed by the original slices.
TEST_CASE(simplify_split_add_relu)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {3, 1, 4}}});
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
        auto y = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto sum2  = mm1->add_instruction(migraphx::make_op("add"), y, twob);
        auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto add   = mm1->add_instruction(migraphx::make_op("add"), relu1, relu2);
        mm1->add_instruction(pass_op{}, add);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2    = p2.get_main_module();
        auto b       = migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {3, 2, 4}}});
        auto input   = mm2->add_parameter("input", s);
        auto one     = mm2->add_literal(1);
        auto two     = mm2->add_literal(2);
        auto concat  = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
        auto concatb = mm2->add_instruction(b, concat);
        auto sum     = mm2->add_instruction(migraphx::make_op("add"), input, concatb);
        auto relu    = mm2->add_instruction(migraphx::make_op("relu"), sum);
        auto x       = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), relu);
        auto y = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), relu);
        auto add = mm2->add_instruction(migraphx::make_op("add"), x, y);
        mm2->add_instruction(pass_op{}, add);
    }
    EXPECT(p1.sort() == p2.sort());
}

// Like simplify_split_add_relu, but each relu result is reshaped to {3, 4}. The expected
// rewrite runs add+relu on the full input, reshapes it to {3, 8}, and slices axis 1 into
// [0,4) and [4,8).
TEST_CASE(simplify_split_add_relu_reshape)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {3, 1, 4}}});
        auto r     = migraphx::make_op("reshape", {{"dims", {3, 4}}});
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
        auto y = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
        auto one      = mm1->add_literal(1);
        auto oneb     = mm1->add_instruction(b, one);
        auto two      = mm1->add_literal(2);
        auto twob     = mm1->add_instruction(b, two);
        auto sum1     = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
        auto relu1    = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto reshape1 = mm1->add_instruction(r, relu1);
        auto sum2     = mm1->add_instruction(migraphx::make_op("add"), y, twob);
        auto relu2    = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto reshape2 = mm1->add_instruction(r, relu2);
        auto add      = mm1->add_instruction(migraphx::make_op("add"), reshape1, reshape2);
        mm1->add_instruction(pass_op{}, add);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2    = p2.get_main_module();
        auto b       = migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {3, 2, 4}}});
        auto input   = mm2->add_parameter("input", s);
        auto one     = mm2->add_literal(1);
        auto two     = mm2->add_literal(2);
        auto concat  = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
        auto concatb = mm2->add_instruction(b, concat);
        auto sum     = mm2->add_instruction(migraphx::make_op("add"), input, concatb);
        auto relu    = mm2->add_instruction(migraphx::make_op("relu"), sum);
        auto rsp     = mm2->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 8}}}), relu);
        auto slc1    = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {4}}}), rsp);
        auto slc2 = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {4}}, {"ends", {8}}}), rsp);
        auto add = mm2->add_instruction(migraphx::make_op("add"), slc1, slc2);
        mm2->add_instruction(pass_op{}, add);
    }
    EXPECT(p1.sort() == p2.sort());
}

// The two slices are taken along different axes (1 and 3), so they are not a split of the
// same axis; the pass is expected to leave the program unchanged.
TEST_CASE(simplify_slice_different_axis)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 2}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto r     = migraphx::make_op("reshape", {{"dims", {3, 2, 4}}});
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
        auto y = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {1}}}), input);
        auto one  = mm1->add_literal(1);
        auto oneb = mm1->add_instruction(
            migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {3, 1, 4, 2}}}), one);
        auto two  = mm1->add_literal(2);
        auto twob = mm1->add_instruction(
            migraphx::make_op("broadcast", {{"axis", 3}, {"dims", {3, 2, 4, 1}}}), two);
        auto sum1     = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
        auto relu1    = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto reshape1 = mm1->add_instruction(r, relu1);
        auto sum2     = mm1->add_instruction(migraphx::make_op("add"), y, twob);
        auto relu2    = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto reshape2 = mm1->add_instruction(r, relu2);
        auto add      = mm1->add_instruction(migraphx::make_op("add"), reshape1, reshape2);
        mm1->add_instruction(pass_op{}, add);
    }
    // Snapshot before the pass; the pass should be a no-op.
    migraphx::program p2 = p1;
    run_pass(p1);

    EXPECT(p1.sort() == p2.sort());
}

// Axis 1 has size 3 but only slices [2,3) and [1,2) are used; without the [0,1) slice the
// split does not cover the input, so the pass is expected to leave the program unchanged.
TEST_CASE(simplify_slice_missing_begining_slice)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {3, 1, 4}}});
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input);
        auto y = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto sum2  = mm1->add_instruction(migraphx::make_op("add"), y, twob);
        auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto add   = mm1->add_instruction(migraphx::make_op("add"), relu1, relu2);
        mm1->add_instruction(pass_op{}, add);
    }
    // Snapshot before the pass; the pass should be a no-op.
    migraphx::program p2 = p1;
    run_pass(p1);

    EXPECT(p1.sort() == p2.sort());
}

// Axis 1 has size 3 but only slices [2,3) and [0,1) are used; without the [1,2) slice the
// split does not cover the input, so the pass is expected to leave the program unchanged.
TEST_CASE(simplify_slice_missing_middle_slice)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {3, 1, 4}}});
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input);
        auto y = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto sum2  = mm1->add_instruction(migraphx::make_op("add"), y, twob);
        auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto add   = mm1->add_instruction(migraphx::make_op("add"), relu1, relu2);
        mm1->add_instruction(pass_op{}, add);
    }
    // Snapshot before the pass; the pass should be a no-op.
    migraphx::program p2 = p1;
    run_pass(p1);

    EXPECT(p1.sort() == p2.sort());
}

// Axis 1 has size 3 but only slices [0,1) and [1,2) are used; without the [2,3) slice the
// split does not cover the input, so the pass is expected to leave the program unchanged.
TEST_CASE(simplify_slice_missing_end_slice)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {3, 1, 4}}});
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
        auto y = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto sum2  = mm1->add_instruction(migraphx::make_op("add"), y, twob);
        auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto add   = mm1->add_instruction(migraphx::make_op("add"), relu1, relu2);
        mm1->add_instruction(pass_op{}, add);
    }
    // Snapshot before the pass; the pass should be a no-op.
    migraphx::program p2 = p1;
    run_pass(p1);

    EXPECT(p1.sort() == p2.sort());
}

// The split results are concatenated back along the split axis (1), so the slices and the
// concat are expected to cancel out, leaving a single add+relu on the full input.
TEST_CASE(simplify_split_add_relu_concat_same_axis)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {3, 1, 4}}});
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
        auto y = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto sum2  = mm1->add_instruction(migraphx::make_op("add"), y, twob);
        auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto concat =
            mm1->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), relu1, relu2);
        mm1->add_instruction(pass_op{}, concat);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2    = p2.get_main_module();
        auto b       = migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {3, 2, 4}}});
        auto input   = mm2->add_parameter("input", s);
        auto one     = mm2->add_literal(1);
        auto two     = mm2->add_literal(2);
        auto concat  = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
        auto concatb = mm2->add_instruction(b, concat);
        auto sum     = mm2->add_instruction(migraphx::make_op("add"), input, concatb);
        auto relu    = mm2->add_instruction(migraphx::make_op("relu"), sum);
        mm2->add_instruction(pass_op{}, relu);
    }
    EXPECT(p1.sort() == p2.sort());
}

// Each slice spans two axes at once (axes 1 and 3); this split is expected not to be
// merged, so the program must be unchanged by the pass.
TEST_CASE(simplify_split_add_relu_multi_axes)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 6}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {3, 1, 4, 3}}});
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1, 3}}, {"starts", {0, 0}}, {"ends", {1, 3}}}),
            input);
        auto y = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1, 3}}, {"starts", {1, 3}}, {"ends", {2, 6}}}),
            input);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto sum2  = mm1->add_instruction(migraphx::make_op("add"), y, twob);
        auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto add   = mm1->add_instruction(migraphx::make_op("add"), relu1, relu2);
        mm1->add_instruction(pass_op{}, add);
    }
    // Snapshot before the pass; the pass should be a no-op.
    migraphx::program p2 = p1;
    run_pass(p1);
    EXPECT(p1.sort() == p2.sort());
}

// Slice x has an extra user (add2). The add+relu is still expected to be hoisted over the
// full input, while add2 keeps consuming its own [0,1) slice of the input.
TEST_CASE(simplify_split_add_relu_used_multiple_split1)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {3, 1, 4}}});
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
        auto y = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto sum2  = mm1->add_instruction(migraphx::make_op("add"), y, twob);
        auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto add1  = mm1->add_instruction(migraphx::make_op("add"), relu1, relu2);
        auto add2  = mm1->add_instruction(migraphx::make_op("add"), x, add1);
        mm1->add_instruction(pass_op{}, add2);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2  = p2.get_main_module();
        auto b     = migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {3, 2, 4}}});
        auto input = mm2->add_parameter("input", s);
        auto slice = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
        auto one     = mm2->add_literal(1);
        auto two     = mm2->add_literal(2);
        auto concat  = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
        auto concatb = mm2->add_instruction(b, concat);
        auto sum     = mm2->add_instruction(migraphx::make_op("add"), input, concatb);
        auto relu    = mm2->add_instruction(migraphx::make_op("relu"), sum);
        auto x       = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), relu);
        auto y = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), relu);
        auto add1 = mm2->add_instruction(migraphx::make_op("add"), x, y);
        auto add2 = mm2->add_instruction(migraphx::make_op("add"), slice, add1);
        mm2->add_instruction(pass_op{}, add2);
    }
    EXPECT(p1.sort() == p2.sort());
}

// Two channel slices of `input` each go through add(broadcast constant) -> relu and are
// then summed; slice `x` additionally feeds a separate relu `z`. The expected program
// performs one add/relu over the whole input (constants concatenated and broadcast),
// re-slices the result, and keeps `z` on its own slice of the original input.
TEST_CASE(simplify_split_add_relu_used_multiple_split2)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto b     = migraphx::op::broadcast{1, {3, 1, 4}};
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
        auto y = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
        // Second consumer of slice `x`, outside the add/relu chain being fused.
        auto z     = mm1->add_instruction(migraphx::make_op("relu"), x);
        auto one   = mm1->add_literal(1);
        auto oneb  = mm1->add_instruction(b, one);
        auto two   = mm1->add_literal(2);
        auto twob  = mm1->add_instruction(b, two);
        auto sum1  = mm1->add_instruction(migraphx::make_op("add"), x, oneb);
        auto relu1 = mm1->add_instruction(migraphx::make_op("relu"), sum1);
        auto sum2  = mm1->add_instruction(migraphx::make_op("add"), y, twob);
        auto relu2 = mm1->add_instruction(migraphx::make_op("relu"), sum2);
        auto add1  = mm1->add_instruction(migraphx::make_op("add"), relu1, relu2);
        auto add2  = mm1->add_instruction(migraphx::make_op("add"), z, add1);
        mm1->add_instruction(pass_op{}, add2);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2  = p2.get_main_module();
        auto b     = migraphx::op::broadcast{1, {3, 2, 4}};
        auto input = mm2->add_parameter("input", s);
        // The slice still used by `z` survives the rewrite.
        auto slice = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
        auto z       = mm2->add_instruction(migraphx::make_op("relu"), slice);
        auto one     = mm2->add_literal(1);
        auto two     = mm2->add_literal(2);
        auto concat  = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
        auto concatb = mm2->add_instruction(b, concat);
        auto sum     = mm2->add_instruction(migraphx::make_op("add"), input, concatb);
        auto relu    = mm2->add_instruction(migraphx::make_op("relu"), sum);
        auto x       = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), relu);
        auto y = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), relu);
        auto add1 = mm2->add_instruction(migraphx::make_op("add"), x, y);
        auto add2 = mm2->add_instruction(migraphx::make_op("add"), z, add1);
        mm2->add_instruction(pass_op{}, add2);
    }
    EXPECT(p1.sort() == p2.sort());
}

// Two slices of the same input are added directly, with no elementwise op in between.
// `p2` is a copy of `p1` taken before the pass, so the pass must leave this graph unchanged.
TEST_CASE(simplify_split_between_add)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto input = mm1->add_parameter("input", s);
        auto x     = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
        auto y = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
        auto sum = mm1->add_instruction(migraphx::make_op("add"), x, y);
        mm1->add_instruction(pass_op{}, sum);
    }
    // Snapshot before the pass: the expected result is "no change".
    migraphx::program p2 = p1;
    run_pass(p1);
    EXPECT(p1.sort() == p2.sort());
}

// Two dots sharing the left operand `input` with different constant right operands.
// Expected: the constants are concatenated on axis 2, a single dot is computed, and the
// result is sliced back into the two original outputs along axis 2.
TEST_CASE(simplify_dot_horiz)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto input = mm1->add_parameter("input", s);
        auto a     = mm1->add_literal(migraphx::generate_literal(s, 0));
        auto b     = mm1->add_literal(migraphx::generate_literal(s, 1));
        auto x     = mm1->add_instruction(migraphx::make_op("dot"), input, a);
        auto y     = mm1->add_instruction(migraphx::make_op("dot"), input, b);
        auto sum   = mm1->add_instruction(migraphx::make_op("add"), x, y);
        mm1->add_instruction(pass_op{}, sum);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2   = p2.get_main_module();
        auto input  = mm2->add_parameter("input", s);
        auto a      = mm2->add_literal(migraphx::generate_literal(s, 0));
        auto b      = mm2->add_literal(migraphx::generate_literal(s, 1));
        auto concat = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 2}}), a, b);
        auto dot    = mm2->add_instruction(migraphx::make_op("dot"), input, concat);
        auto x      = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), dot);
        auto y = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {4}}}), dot);
        auto sum = mm2->add_instruction(migraphx::make_op("add"), x, y);
        mm2->add_instruction(pass_op{}, sum);
    }
    EXPECT(p1.sort() == p2.sort());
}

// Same as simplify_dot_horiz, but both dots use the *same* constant. Expected: the
// fusion still happens, concatenating the constant with itself on axis 2.
TEST_CASE(simplify_dot_horiz_same_constant)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto input = mm1->add_parameter("input", s);
        auto a     = mm1->add_literal(migraphx::generate_literal(s, 0));
        auto x     = mm1->add_instruction(migraphx::make_op("dot"), input, a);
        auto y     = mm1->add_instruction(migraphx::make_op("dot"), input, a);
        auto sum   = mm1->add_instruction(migraphx::make_op("add"), x, y);
        mm1->add_instruction(pass_op{}, sum);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2   = p2.get_main_module();
        auto input  = mm2->add_parameter("input", s);
        auto a      = mm2->add_literal(migraphx::generate_literal(s, 0));
        auto concat = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 2}}), a, a);
        auto dot    = mm2->add_instruction(migraphx::make_op("dot"), input, concat);
        auto x      = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), dot);
        auto y = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {4}}}), dot);
        auto sum = mm2->add_instruction(migraphx::make_op("add"), x, y);
        mm2->add_instruction(pass_op{}, sum);
    }
    EXPECT(p1.sort() == p2.sort());
}

// `input` is the left operand of one dot and the right operand of the other, so the two
// dots do not share an operand position. `p2` is a pre-pass copy, so no fusion is expected.
TEST_CASE(simplify_dot_horiz_flipped)
{
    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto input = mm1->add_parameter("input", s);
        auto a     = mm1->add_literal(migraphx::generate_literal(s, 0));
        auto b     = mm1->add_literal(migraphx::generate_literal(s, 1));
        auto x     = mm1->add_instruction(migraphx::make_op("dot"), input, a);
        // Operand order flipped relative to `x`.
        auto y   = mm1->add_instruction(migraphx::make_op("dot"), b, input);
        auto sum = mm1->add_instruction(migraphx::make_op("add"), x, y);
        mm1->add_instruction(pass_op{}, sum);
    }

    migraphx::program p2 = p1;
    run_pass(p1);
    EXPECT(p1.sort() == p2.sort());
}

// Two convolutions of the same input with different weights. Expected: the weights are
// concatenated on axis 0 (output channels), one convolution is computed, and its result
// is sliced along axis 1 back into the two 12-channel outputs.
TEST_CASE(simplify_conv_horiz)
{
    auto s  = migraphx::shape{migraphx::shape::int32_type, {8, 3, 64, 64}};
    auto ws = migraphx::shape{migraphx::shape::int32_type, {12, 3, 3, 3}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto input = mm1->add_parameter("input", s);
        auto a     = mm1->add_literal(migraphx::generate_literal(ws, 0));
        auto b     = mm1->add_literal(migraphx::generate_literal(ws, 1));
        auto x     = mm1->add_instruction(migraphx::make_op("convolution"), input, a);
        auto y     = mm1->add_instruction(migraphx::make_op("convolution"), input, b);
        auto sum   = mm1->add_instruction(migraphx::make_op("add"), x, y);
        mm1->add_instruction(pass_op{}, sum);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2   = p2.get_main_module();
        auto input  = mm2->add_parameter("input", s);
        auto a      = mm2->add_literal(migraphx::generate_literal(ws, 0));
        auto b      = mm2->add_literal(migraphx::generate_literal(ws, 1));
        auto concat = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), a, b);
        auto conv   = mm2->add_instruction(migraphx::make_op("convolution"), input, concat);
        auto x      = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}), conv);
        auto y = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}), conv);
        auto sum = mm2->add_instruction(migraphx::make_op("add"), x, y);
        mm2->add_instruction(pass_op{}, sum);
    }
    EXPECT(p1.sort() == p2.sort());
}

1469
1470
1471
1472
1473
1474
TEST_CASE(simplify_group_conv_horiz)
{
    auto s  = migraphx::shape{migraphx::shape::int32_type, {1, 32, 111, 111}};
    auto ws = migraphx::shape{migraphx::shape::int32_type, {32, 1, 7, 7}};
    migraphx::program p1;
    {
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
        auto* mm1  = p1.get_main_module();
        auto x     = mm1->add_parameter("x", s);
        auto w1    = mm1->add_literal(migraphx::generate_literal(ws, 1));
        auto w2    = mm1->add_literal(migraphx::generate_literal(ws, 2));
        auto conv1 = mm1->add_instruction(
            migraphx::make_op(
                "convolution",
                {{"padding", {3, 3}}, {"stride", {2, 2}}, {"dilation", {1, 1}}, {"group", 32}}),
            x,
            w1);
        auto conv2 = mm1->add_instruction(
            migraphx::make_op(
                "convolution",
                {{"padding", {3, 3}}, {"stride", {2, 2}}, {"dilation", {1, 1}}, {"group", 32}}),
            x,
            w2);
1491
        mm1->add_instruction(pass_op{}, conv1, conv2);
1492
1493
1494
1495
1496
1497
1498
1499
    }
    migraphx::program p2 = p1;
    run_pass(p1);

    EXPECT(p1.sort() == p2.sort());
}

// Two independent horizontal-fusion candidates hang off the same input: a pair of
// convolutions and a pair of dots. Expected: each pair is fused separately. The conv
// weights are concatenated on axis 0 and the output is sliced on axis 1. The dot right
// operands are concatenated on axis 3 and the output is sliced on axis 3.
TEST_CASE(simplify_conv_horiz_grouped)
{
    auto s   = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
    auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}};
    auto ws2 = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
    migraphx::program p1;
    {
        auto* mm1  = p1.get_main_module();
        auto input = mm1->add_parameter("input", s);
        auto a     = mm1->add_literal(migraphx::generate_literal(ws1, 0));
        auto b     = mm1->add_literal(migraphx::generate_literal(ws1, 1));
        auto c     = mm1->add_literal(migraphx::generate_literal(ws2, 2));
        auto d     = mm1->add_literal(migraphx::generate_literal(ws2, 3));
        auto convx =
            mm1->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, a);
        auto convy =
            mm1->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, b);
        auto dotx = mm1->add_instruction(migraphx::make_op("dot"), input, c);
        auto doty = mm1->add_instruction(migraphx::make_op("dot"), input, d);
        auto sum1 = mm1->add_instruction(migraphx::make_op("add"), convx, convy);
        auto sum2 = mm1->add_instruction(migraphx::make_op("add"), dotx, doty);
        auto sum3 = mm1->add_instruction(migraphx::make_op("add"), sum1, sum2);
        mm1->add_instruction(pass_op{}, sum3);
    }
    run_pass(p1);

    migraphx::program p2;
    {
        auto* mm2    = p2.get_main_module();
        auto input   = mm2->add_parameter("input", s);
        auto a       = mm2->add_literal(migraphx::generate_literal(ws1, 0));
        auto b       = mm2->add_literal(migraphx::generate_literal(ws1, 1));
        auto c       = mm2->add_literal(migraphx::generate_literal(ws2, 2));
        auto d       = mm2->add_literal(migraphx::generate_literal(ws2, 3));
        auto concat1 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), a, b);
        auto concat2 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 3}}), c, d);
        auto conv    = mm2->add_instruction(
            migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, concat1);
        auto convx = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {6}}}), conv);
        auto convy = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {6}}, {"ends", {12}}}), conv);
        auto sum1 = mm2->add_instruction(migraphx::make_op("add"), convx, convy);
        auto dot  = mm2->add_instruction(migraphx::make_op("dot"), input, concat2);
        auto dotx = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {64}}}), dot);
        auto doty = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {128}}}), dot);
        auto sum2 = mm2->add_instruction(migraphx::make_op("add"), dotx, doty);
        auto sum3 = mm2->add_instruction(migraphx::make_op("add"), sum1, sum2);
        mm2->add_instruction(pass_op{}, sum3);
    }
    EXPECT(p1.sort() == p2.sort());
}

1555
TEST_CASE(simplify_conv_horiz_grouped_extra1)
1556
1557
1558
1559
1560
1561
{
    auto s   = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
    auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}};
    auto ws2 = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
    migraphx::program p1;
    {
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
        auto* mm1  = p1.get_main_module();
        auto input = mm1->add_parameter("input", s);
        auto a     = mm1->add_literal(migraphx::generate_literal(ws1, 0));
        auto b     = mm1->add_literal(migraphx::generate_literal(ws1, 1));
        auto c     = mm1->add_literal(migraphx::generate_literal(ws2, 2));
        auto d     = mm1->add_literal(migraphx::generate_literal(ws2, 3));
        auto e     = mm1->add_literal(migraphx::generate_literal(s, 4));
        auto convx =
            mm1->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, a);
        auto convy =
            mm1->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, b);
        auto dotx    = mm1->add_instruction(migraphx::make_op("dot"), input, c);
        auto doty    = mm1->add_instruction(migraphx::make_op("dot"), input, d);
        auto sqdiffx = mm1->add_instruction(migraphx::make_op("sqdiff"), input, e);
        auto sum1    = mm1->add_instruction(migraphx::make_op("add"), convx, convy);
        auto sum2    = mm1->add_instruction(migraphx::make_op("add"), dotx, doty);
1578
        auto sum3    = sqdiffx;
1579
1580
        auto sum4    = mm1->add_instruction(migraphx::make_op("add"), sum1, sum2);
        auto sum5    = mm1->add_instruction(migraphx::make_op("add"), sum4, sum3);
1581
        mm1->add_instruction(pass_op{}, sum5);
1582
1583
1584
1585
1586
    }
    run_pass(p1);

    migraphx::program p2;
    {
1587
1588
1589
1590
1591
1592
1593
        auto* mm2    = p2.get_main_module();
        auto input   = mm2->add_parameter("input", s);
        auto a       = mm2->add_literal(migraphx::generate_literal(ws1, 0));
        auto b       = mm2->add_literal(migraphx::generate_literal(ws1, 1));
        auto c       = mm2->add_literal(migraphx::generate_literal(ws2, 2));
        auto d       = mm2->add_literal(migraphx::generate_literal(ws2, 3));
        auto e       = mm2->add_literal(migraphx::generate_literal(s, 4));
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
        auto concat1 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), a, b);
        auto concat2 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 3}}), c, d);
        auto conv    = mm2->add_instruction(
            migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, concat1);
        auto convx = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {6}}}), conv);
        auto convy = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {6}}, {"ends", {12}}}), conv);
        auto sum1 = mm2->add_instruction(migraphx::make_op("add"), convx, convy);
        auto dot  = mm2->add_instruction(migraphx::make_op("dot"), input, concat2);
        auto dotx = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {64}}}), dot);
        auto doty = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {128}}}), dot);
        auto sum2    = mm2->add_instruction(migraphx::make_op("add"), dotx, doty);
        auto sqdiffx = mm2->add_instruction(migraphx::make_op("sqdiff"), input, e);
1610
        auto sum3    = sqdiffx;
1611
1612
        auto sum4    = mm2->add_instruction(migraphx::make_op("add"), sum1, sum2);
        auto sum5    = mm2->add_instruction(migraphx::make_op("add"), sum4, sum3);
1613
        mm2->add_instruction(pass_op{}, sum5);
1614
1615
1616
1617
    }
    EXPECT(p1.sort() == p2.sort());
}

1618
TEST_CASE(simplify_conv_horiz_grouped_extra2)
1619
1620
1621
1622
1623
1624
{
    auto s   = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
    auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}};
    auto ws2 = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
    migraphx::program p1;
    {
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
        auto* mm1  = p1.get_main_module();
        auto input = mm1->add_parameter("input", s);
        auto a     = mm1->add_literal(migraphx::generate_literal(ws1, 0));
        auto b     = mm1->add_literal(migraphx::generate_literal(ws1, 1));
        auto c     = mm1->add_literal(migraphx::generate_literal(ws2, 2));
        auto d     = mm1->add_literal(migraphx::generate_literal(ws2, 3));
        auto e     = mm1->add_literal(migraphx::generate_literal(s, 4));
        auto f     = mm1->add_literal(migraphx::generate_literal(s, 5));
        auto convx =
            mm1->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, a);
        auto convy =
            mm1->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, b);
        auto dotx    = mm1->add_instruction(migraphx::make_op("dot"), input, c);
        auto doty    = mm1->add_instruction(migraphx::make_op("dot"), input, d);
        auto sqdiffx = mm1->add_instruction(migraphx::make_op("sqdiff"), input, e);
        auto sqdiffy = mm1->add_instruction(migraphx::make_op("sqdiff"), input, f);
        auto sum1    = mm1->add_instruction(migraphx::make_op("add"), convx, convy);
        auto sum2    = mm1->add_instruction(migraphx::make_op("add"), dotx, doty);
        auto sum3    = mm1->add_instruction(migraphx::make_op("add"), sqdiffx, sqdiffy);
        auto sum4    = mm1->add_instruction(migraphx::make_op("add"), sum1, sum2);
        auto sum5    = mm1->add_instruction(migraphx::make_op("add"), sum4, sum3);
1646
        mm1->add_instruction(pass_op{}, sum5);
1647
1648
1649
1650
1651
    }
    run_pass(p1);

    migraphx::program p2;
    {
1652
1653
1654
1655
1656
1657
1658
1659
        auto* mm2    = p2.get_main_module();
        auto input   = mm2->add_parameter("input", s);
        auto a       = mm2->add_literal(migraphx::generate_literal(ws1, 0));
        auto b       = mm2->add_literal(migraphx::generate_literal(ws1, 1));
        auto c       = mm2->add_literal(migraphx::generate_literal(ws2, 2));
        auto d       = mm2->add_literal(migraphx::generate_literal(ws2, 3));
        auto e       = mm2->add_literal(migraphx::generate_literal(s, 4));
        auto f       = mm2->add_literal(migraphx::generate_literal(s, 5));
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
        auto concat1 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), a, b);
        auto concat2 = mm2->add_instruction(migraphx::make_op("concat", {{"axis", 3}}), c, d);
        auto conv    = mm2->add_instruction(
            migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, concat1);
        auto convx = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {6}}}), conv);
        auto convy = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {6}}, {"ends", {12}}}), conv);
        auto sum1 = mm2->add_instruction(migraphx::make_op("add"), convx, convy);
        auto dot  = mm2->add_instruction(migraphx::make_op("dot"), input, concat2);
        auto dotx = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {64}}}), dot);
        auto doty = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {128}}}), dot);
        auto sum2    = mm2->add_instruction(migraphx::make_op("add"), dotx, doty);
        auto sqdiffx = mm2->add_instruction(migraphx::make_op("sqdiff"), input, e);
        auto sqdiffy = mm2->add_instruction(migraphx::make_op("sqdiff"), input, f);
        auto sum3    = mm2->add_instruction(migraphx::make_op("add"), sqdiffx, sqdiffy);
        auto sum4    = mm2->add_instruction(migraphx::make_op("add"), sum1, sum2);
        auto sum5    = mm2->add_instruction(migraphx::make_op("add"), sum4, sum3);
1680
        mm2->add_instruction(pass_op{}, sum5);
1681
1682
1683
1684
    }
    EXPECT(p1.sort() == p2.sort());
}

1685
1686
1687
1688
TEST_CASE(simplify_mul_slice_conv_horiz_fusion)
{
    migraphx::program p1;
    {
1689
1690
1691
        auto* mm1 = p1.get_main_module();
        auto x    = mm1->add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
        auto w    = mm1->add_literal(
1692
            migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
1693
1694
1695
        auto conv   = mm1->add_instruction(migraphx::make_op("convolution"), x, w);
        auto slice1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv);
1696
        auto a1 =
1697
            mm1->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1));
1698
1699
1700
        auto b1 = mm1->add_instruction(
            migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a1);
        auto mul = mm1->add_instruction(migraphx::make_op("mul"), slice1, b1);
1701
        auto a2 =
1702
            mm1->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 2));
1703
1704
1705
        auto b2 = mm1->add_instruction(
            migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a2);
        auto add1 = mm1->add_instruction(migraphx::make_op("add"), mul, b2);
1706
        auto a3 =
1707
            mm1->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3));
1708
1709
1710
1711
1712
        auto b3 = mm1->add_instruction(
            migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 384, 17, 17}}}), a3);
        auto slice2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), conv);
        auto add2 = mm1->add_instruction(migraphx::make_op("add"), slice2, b3);
1713
        mm1->add_instruction(pass_op{}, add1, add2);
1714
1715
1716
1717
1718
    }
    run_pass(p1);

    migraphx::program p2;
    {
1719
1720
1721
        auto* mm2 = p2.get_main_module();
        auto x    = mm2->add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
        auto w    = mm2->add_literal(
1722
            migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
1723
1724
        auto wslice1 = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {384}}}), w);
1725
        auto a1 =
1726
            mm2->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1));
1727
1728
1729
1730
1731
1732
1733
1734
        auto b1 = mm2->add_instruction(
            migraphx::make_op("broadcast", {{"axis", 0}, {"dims", {384, 1024, 1, 1}}}), a1);
        auto mul     = mm2->add_instruction(migraphx::make_op("mul"), b1, wslice1);
        auto wslice2 = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {384}}, {"ends", {768}}}), w);
        auto concat1 =
            mm2->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), mul, wslice2);
        auto conv = mm2->add_instruction(migraphx::make_op("convolution"), x, concat1);
1735
        auto a2 =
1736
            mm2->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 2));
1737
        auto a3 =
1738
            mm2->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3));
1739
1740
1741
1742
1743
1744
1745
1746
        auto concat2 = mm2->add_instruction(migraphx::make_op("concat"), a2, a3);
        auto b4      = mm2->add_instruction(
            migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 768, 17, 17}}}), concat2);
        auto add    = mm2->add_instruction(migraphx::make_op("add"), conv, b4);
        auto slice1 = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), add);
        auto slice2 = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), add);
1747
        mm2->add_instruction(pass_op{}, slice1, slice2);
1748
1749
1750
    }
    EXPECT(p1.sort() == p2.sort());
}
1751
1752
1753
1754
1755
1756
TEST_CASE(reorder_reshape_slice)
{
    std::vector<int64_t> perm0 = {0, 2, 1, 3};
    std::vector<int64_t> perm1 = {0, 2, 3, 1};
    auto create_p1             = [&](std::size_t batch_size) {
        migraphx::program p1;
1757
        auto* mm1  = p1.get_main_module();
1758
        auto s     = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
1759
        auto input = mm1->add_parameter("input", s);
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
        auto slc0  = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {640}}}), input);
        auto slc1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {640}}, {"ends", {1280}}}),
            input);
        auto slc2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}),
            input);

        auto c0 = mm1->add_instruction(migraphx::make_op("contiguous"), slc0);
        auto c1 = mm1->add_instruction(migraphx::make_op("contiguous"), slc1);
        auto c2 = mm1->add_instruction(migraphx::make_op("contiguous"), slc2);
1772
1773

        std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 128, 10, 64};
1774
1775
1776
        auto r0 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
        auto r1 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
        auto r2 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);
1777

1778
1779
1780
        auto t0 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), r0);
        auto t1 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), r1);
        auto t2 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), r2);
1781

1782
1783
        auto sum = mm1->add_instruction(migraphx::make_op("add"), t0, t1);
        auto ret = mm1->add_instruction(migraphx::make_op("dot"), sum, t2);
1784
        mm1->add_return({ret});
1785
1786
1787
1788
1789
1790

        return p1;
    };

    auto create_p2 = [&](std::size_t batch_size) {
        migraphx::program p2;
1791
        auto* mm2  = p2.get_main_module();
1792
        auto s     = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
1793
        auto input = mm2->add_parameter("input", s);
1794
        std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 128, 30, 64};
1795
        auto r = mm2->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
1796

1797
1798
1799
1800
1801
1802
        auto slc0 = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {10}}}), r);
        auto slc1 = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {10}}, {"ends", {20}}}), r);
        auto slc2 = mm2->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {20}}, {"ends", {30}}}), r);
1803

1804
1805
1806
        auto t0 = mm2->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc0);
        auto t1 = mm2->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc1);
        auto t2 = mm2->add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), slc2);
1807

1808
1809
        auto sum = mm2->add_instruction(migraphx::make_op("add"), t0, t1);
        auto ret = mm2->add_instruction(migraphx::make_op("dot"), sum, t2);
1810
        mm2->add_return({ret});
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826

        return p2;
    };

    auto test = [&](std::size_t batch_size) {
        auto p1 = create_p1(batch_size);
        run_pass(p1);
        auto p2 = create_p2(batch_size);
        EXPECT(p1.sort() == p2.sort());
    };

    test(1);
    test(4);
    test(8);
}

1827
TEST_CASE(reorder_reshape_slice_move_axis1)
{
    // Permutation used for the first two slices, and the one used for the last slice.
    const std::vector<int64_t> perm_ab = {0, 2, 1, 3};
    const std::vector<int64_t> perm_c  = {0, 2, 3, 1};

    // Program handed to the pass: split axis 2 of a {batch, 256, 96} input into
    // three 32-wide slices. Each slice is made contiguous, reshaped to
    // {batch, 64, 4, 32} and transposed. The results feed an add followed by a dot.
    auto build_source = [&](std::size_t batch) {
        migraphx::program prog;
        auto* mod = prog.get_main_module();
        auto in   = mod->add_parameter(
            "input", migraphx::shape{migraphx::shape::float_type, {batch, 256, 96}});

        auto cut = [&](int64_t lo, int64_t hi) {
            return mod->add_instruction(
                migraphx::make_op("slice", {{"axes", {2}}, {"starts", {lo}}, {"ends", {hi}}}),
                in);
        };
        auto first  = cut(0, 32);
        auto second = cut(32, 64);
        auto third  = cut(64, 96);

        auto contig = [&](auto x) {
            return mod->add_instruction(migraphx::make_op("contiguous"), x);
        };
        first  = contig(first);
        second = contig(second);
        third  = contig(third);

        const std::vector<int64_t> split_dims = {static_cast<int64_t>(batch), 64, 4, 32};
        auto reshape = [&](auto x) {
            return mod->add_instruction(migraphx::make_op("reshape", {{"dims", split_dims}}), x);
        };
        first  = reshape(first);
        second = reshape(second);
        third  = reshape(third);

        auto transpose = [&](const std::vector<int64_t>& perm, auto x) {
            return mod->add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), x);
        };
        first  = transpose(perm_ab, first);
        second = transpose(perm_ab, second);
        third  = transpose(perm_c, third);

        auto added = mod->add_instruction(migraphx::make_op("add"), first, second);
        auto prod  = mod->add_instruction(migraphx::make_op("dot"), added, third);
        mod->add_return({prod});
        return prog;
    };

    // Expected program: reshape the whole input once to {batch, 64, 4, 96}, then
    // slice along the new last axis. Each slice is transposed immediately after it
    // is cut, which keeps the interleaved instruction order.
    auto build_expected = [&](std::size_t batch) {
        migraphx::program prog;
        auto* mod = prog.get_main_module();
        auto in   = mod->add_parameter(
            "input", migraphx::shape{migraphx::shape::float_type, {batch, 256, 96}});

        const std::vector<int64_t> merged_dims = {static_cast<int64_t>(batch), 64, 4, 96};
        auto merged =
            mod->add_instruction(migraphx::make_op("reshape", {{"dims", merged_dims}}), in);

        auto slice_then_transpose = [&](int64_t lo, int64_t hi, const std::vector<int64_t>& perm) {
            auto piece = mod->add_instruction(
                migraphx::make_op("slice", {{"axes", {3}}, {"starts", {lo}}, {"ends", {hi}}}),
                merged);
            return mod->add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), piece);
        };
        auto first  = slice_then_transpose(0, 32, perm_ab);
        auto second = slice_then_transpose(32, 64, perm_ab);
        auto third  = slice_then_transpose(64, 96, perm_c);

        auto added = mod->add_instruction(migraphx::make_op("add"), first, second);
        auto prod  = mod->add_instruction(migraphx::make_op("dot"), added, third);
        mod->add_return({prod});
        return prog;
    };

    const std::vector<std::size_t> batches = {4, 8};
    for(auto batch : batches)
    {
        auto actual   = build_source(batch);
        auto expected = build_expected(batch);
        run_pass(actual);
        EXPECT(actual.sort() == expected.sort());
    }
}

1900
1901
1902
1903
TEST_CASE(reorder_reshape_slice_move_axis2)
{
    const migraphx::shape in_shape{migraphx::shape::float_type, {128, 96}};

    // Program handed to the pass: three 32-column slices of the input. Each slice
    // is made contiguous and reshaped to {1, 16, 8, 32}; the slices are then
    // combined with add followed by mul.
    auto make_before = [&] {
        migraphx::program prog;
        auto* mod = prog.get_main_module();
        auto x    = mod->add_parameter("input", in_shape);

        std::vector<migraphx::instruction_ref> parts;
        for(int64_t lo : {0, 32, 64})
        {
            parts.push_back(mod->add_instruction(
                migraphx::make_op("slice", {{"axes", {1}}, {"starts", {lo}}, {"ends", {lo + 32}}}),
                x));
        }
        for(auto& part : parts)
            part = mod->add_instruction(migraphx::make_op("contiguous"), part);

        const std::vector<int64_t> dims = {1, 16, 8, 32};
        for(auto& part : parts)
            part = mod->add_instruction(migraphx::make_op("reshape", {{"dims", dims}}), part);

        auto sum = mod->add_instruction(migraphx::make_op("add"), parts[0], parts[1]);
        auto ret = mod->add_instruction(migraphx::make_op("mul"), sum, parts[2]);
        mod->add_return({ret});
        return prog;
    };

    // Expected program: a single reshape of the input to {1, 16, 8, 96}, followed
    // by three 32-wide slices along axis 3.
    auto make_after = [&] {
        migraphx::program prog;
        auto* mod = prog.get_main_module();
        auto x    = mod->add_parameter("input", in_shape);

        const std::vector<int64_t> dims = {1, 16, 8, 96};
        auto whole = mod->add_instruction(migraphx::make_op("reshape", {{"dims", dims}}), x);

        std::vector<migraphx::instruction_ref> parts;
        for(int64_t lo : {0, 32, 64})
        {
            parts.push_back(mod->add_instruction(
                migraphx::make_op("slice", {{"axes", {3}}, {"starts", {lo}}, {"ends", {lo + 32}}}),
                whole));
        }

        auto sum = mod->add_instruction(migraphx::make_op("add"), parts[0], parts[1]);
        auto ret = mod->add_instruction(migraphx::make_op("mul"), sum, parts[2]);
        mod->add_return({ret});
        return prog;
    };

    auto before   = make_before();
    auto expected = make_after();
    run_pass(before);
    EXPECT(before.sort() == expected.sort());
}

TEST_CASE(reorder_reshape_slice_not_apply)
{
    // Builds three 32-column slices of a {128, 96} input. Each slice is made
    // contiguous and reshaped to {1, 16, 16, 16`}, and the slices are combined
    // with add followed by mul. The test expects run_pass to leave the program
    // unchanged.
    auto build = [] {
        migraphx::program prog;
        auto* mod  = prog.get_main_module();
        auto input = mod->add_parameter(
            "input", migraphx::shape{migraphx::shape::float_type, {128, 96}});

        auto column_slice = [&](int64_t from, int64_t to) {
            return mod->add_instruction(
                migraphx::make_op("slice", {{"axes", {1}}, {"starts", {from}}, {"ends", {to}}}),
                input);
        };
        auto left   = column_slice(0, 32);
        auto middle = column_slice(32, 64);
        auto right  = column_slice(64, 96);

        left   = mod->add_instruction(migraphx::make_op("contiguous"), left);
        middle = mod->add_instruction(migraphx::make_op("contiguous"), middle);
        right  = mod->add_instruction(migraphx::make_op("contiguous"), right);

        const std::vector<int64_t> target = {1, 16, 16, 16};
        auto op_reshape                   = migraphx::make_op("reshape", {{"dims", target}});
        left   = mod->add_instruction(op_reshape, left);
        middle = mod->add_instruction(op_reshape, middle);
        right  = mod->add_instruction(op_reshape, right);

        auto total   = mod->add_instruction(migraphx::make_op("add"), left, middle);
        auto product = mod->add_instruction(migraphx::make_op("mul"), total, right);
        mod->add_return({product});
        return prog;
    };

    auto optimized = build();
    // Copy taken before the pass, used as the reference for "no change".
    auto reference = optimized;
    run_pass(optimized);
    EXPECT(optimized.sort() == reference.sort());
}

1993
1994
1995
1996
TEST_CASE(reorder_reshape_slice_diff_dims)
{
    // Three 32-wide slices along axis 2 of a {batch, 96, 96} input are made
    // contiguous and reshaped, but not all to the same dims. The first two use
    // {batch, 32, 3, 32} and the third uses {batch, 48, 2, 32}. The test expects
    // the pass to leave such a program unchanged.
    auto create_p1 = [](std::size_t batch_size) {
        migraphx::program p1;
        auto* mm1  = p1.get_main_module();
        auto s     = migraphx::shape{migraphx::shape::float_type, {batch_size, 96, 96}};
        auto input = mm1->add_parameter("input", s);
        auto slc0  = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), input);
        auto slc1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), input);
        auto slc2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {64}}, {"ends", {96}}}), input);

        auto c0 = mm1->add_instruction(migraphx::make_op("contiguous"), slc0);
        auto c1 = mm1->add_instruction(migraphx::make_op("contiguous"), slc1);
        auto c2 = mm1->add_instruction(migraphx::make_op("contiguous"), slc2);

        // The mismatched reshape target for the third slice is the point of this test.
        std::vector<int64_t> lens  = {static_cast<int64_t>(batch_size), 32, 3, 32};
        std::vector<int64_t> lens1 = {static_cast<int64_t>(batch_size), 48, 2, 32};

        auto r0 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
        auto r1 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
        auto r2 = mm1->add_instruction(migraphx::make_op("reshape", {{"dims", lens1}}), c2);

        mm1->add_return({r0, r1, r2});

        return p1;
    };

    auto test = [&](std::size_t batch_size) {
        auto p1 = create_p1(batch_size);
        // Copy before the pass so the comparison is against the unmodified program.
        auto p2 = p1;
        run_pass(p1);
        EXPECT(p1.sort() == p2.sort());
    };

    test(4);
    test(8);
}

TEST_CASE(reorder_slice_trans)
{
    // Every slice is transposed with the same permutation. In the expected program
    // the input is transposed once and the result is sliced along axis 1.
    const std::vector<int64_t> perm   = {0, 2, 1};
    const std::vector<int64_t> bounds = {0, 640, 1280, 1920};

    // Appends one slice of `src` along `axis` for each consecutive pair in `bounds`,
    // in order, and returns the new instructions.
    auto add_slices = [&](auto* mod, int64_t axis, auto src) {
        std::vector<migraphx::instruction_ref> out;
        for(std::size_t i = 0; i + 1 < bounds.size(); ++i)
        {
            out.push_back(mod->add_instruction(
                migraphx::make_op(
                    "slice",
                    {{"axes", {axis}}, {"starts", {bounds[i]}}, {"ends", {bounds[i + 1]}}}),
                src));
        }
        return out;
    };

    // Program handed to the pass: slice axis 2 first, then transpose each slice.
    auto build_input = [&](std::size_t batch) {
        migraphx::program prog;
        auto* mod = prog.get_main_module();
        auto x    = mod->add_parameter(
            "input", migraphx::shape{migraphx::shape::float_type, {batch, 128, 1920}});

        auto pieces = add_slices(mod, 2, x);
        for(auto& piece : pieces)
            piece = mod->add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), piece);

        auto sum = mod->add_instruction(migraphx::make_op("add"), pieces[0], pieces[1]);
        auto ret = mod->add_instruction(migraphx::make_op("mul"), sum, pieces[2]);
        mod->add_return({ret});
        return prog;
    };

    // Expected program: transpose the input once, then slice axis 1.
    auto build_expected = [&](std::size_t batch) {
        migraphx::program prog;
        auto* mod = prog.get_main_module();
        auto x    = mod->add_parameter(
            "input", migraphx::shape{migraphx::shape::float_type, {batch, 128, 1920}});
        auto xt = mod->add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), x);

        auto pieces = add_slices(mod, 1, xt);

        auto sum = mod->add_instruction(migraphx::make_op("add"), pieces[0], pieces[1]);
        auto ret = mod->add_instruction(migraphx::make_op("mul"), sum, pieces[2]);
        mod->add_return({ret});
        return prog;
    };

    auto check = [&](std::size_t batch) {
        auto actual = build_input(batch);
        run_pass(actual);
        auto expected = build_expected(batch);
        EXPECT(actual.sort() == expected.sort());
    };

    check(1);
    check(8);
}

TEST_CASE(reorder_slice_trans_diff_perm)
{
    // The first two slices are transposed with {0, 2, 1}, while the third uses the
    // identity permutation {0, 1, 2}. Because the permutations differ, this test
    // expects the pass to leave the program unchanged.
    auto create_p1 = [](std::size_t batch_size) {
        migraphx::program p1;
        auto* mm1 = p1.get_main_module();
        auto s    = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
        std::vector<int64_t> perm0 = {0, 2, 1};
        std::vector<int64_t> perm1 = {0, 1, 2};
        auto input                 = mm1->add_parameter("input", s);
        auto slc0                  = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {640}}}), input);
        auto slc1 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {640}}, {"ends", {1280}}}),
            input);
        auto slc2 = mm1->add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}),
            input);

        auto t0 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc0);
        auto t1 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm0}}), slc1);
        auto t2 = mm1->add_instruction(migraphx::make_op("transpose", {{"dims", perm1}}), slc2);

        auto sum = mm1->add_instruction(migraphx::make_op("add"), t0, t1);
        auto ret = mm1->add_instruction(migraphx::make_op("dot"), sum, t2);
        mm1->add_return({ret});

        return p1;
    };

    auto test = [&](std::size_t batch_size) {
        auto p1 = create_p1(batch_size);
        // Snapshot the program BEFORE running the pass. Copying it afterwards would
        // compare the optimized program with itself, so the check could never fail.
        auto p2 = p1;
        run_pass(p1);
        EXPECT(p1.sort() == p2.sort());
    };

    test(1);
    test(4);
}

Paul's avatar
Paul committed
2135
int main(int argc, const char* argv[])
{
    // Delegate to the test framework's runner, forwarding the command-line arguments.
    test::run(argc, argv);
}