matcher.cpp 37 KB
Newer Older
Paul's avatar
Paul committed
1
2
#include <migraphx/matcher.hpp>
#include <migraphx/iterator_for.hpp>
Paul's avatar
Paul committed
3
4
5
#include <test.hpp>
#include <basic_ops.hpp>

Paul's avatar
Paul committed
6
namespace match = migraphx::match;
Paul's avatar
Paul committed
7

Paul's avatar
Paul committed
8
9
// Predicate matcher that throws whenever it is evaluated. The lazy
// any_of/all_of/none_of tests place it after a matcher that already decides
// the outcome, and expect a result rather than an exception.
MIGRAPHX_PRED_MATCHER(throws, migraphx::instruction_ref)
{
    MIGRAPHX_THROW("Matcher throws");
}

Paul's avatar
Paul committed
10
template <class M>
11
migraphx::match::matcher_result find_match(migraphx::module& modl, M&& m)
Paul's avatar
Paul committed
12
{
Paul's avatar
Paul committed
13
    migraphx::match::matcher_result result;
14
    for(auto ins : migraphx::iterator_for(modl))
Paul's avatar
Paul committed
15
    {
16
17
        result = migraphx::match::match_instruction(modl, ins, m);
        if(result.result != modl.end())
Paul's avatar
Paul committed
18
19
20
21
22
23
24
            return result;
    }
    return result;
}

void match1()
{
Paul's avatar
Paul committed
25
    migraphx::program p;
26
27
28
29
30

    auto* mm = p.get_main_module();
    auto l   = mm->add_literal(1);
    auto m   = match::standard_shape();
    auto r   = find_match(*mm, m);
Paul's avatar
Paul committed
31
32
33
    EXPECT(bool{r.result == l});
}

Paul's avatar
Paul committed
34
// name("sum") should find the sum instruction.
TEST_CASE(match_name1)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto matcher = match::name("sum");
    auto found   = find_match(mod, matcher);
    EXPECT(bool{found.result == sum_ins});
}

Paul's avatar
Paul committed
48
// No instruction is named "min", so the search should report end().
TEST_CASE(match_name2)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto matcher = match::name("min");
    auto found   = find_match(mod, matcher);
    EXPECT(bool{found.result == mod.end()});
}

Paul's avatar
Paul committed
62
// name("sum") refined with standard_shape() should still find the sum.
TEST_CASE(match_name3)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto matcher = match::name("sum")(match::standard_shape());
    auto found   = find_match(mod, matcher);
    EXPECT(bool{found.result == sum_ins});
}

Paul's avatar
Paul committed
76
// The sum's first argument is a literal, so arg(0)(name("@literal")) matches.
TEST_CASE(match_arg1)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto first_is_literal = match::arg(0)(match::name("@literal"));
    auto matcher          = match::name("sum")(first_is_literal, match::standard_shape());
    auto found            = find_match(mod, matcher);
    EXPECT(bool{found.result == sum_ins});
}

Paul's avatar
Paul committed
90
// The sum's first argument is not a sum, so nothing should match.
TEST_CASE(match_arg2)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto first_is_sum = match::arg(0)(match::name("sum"));
    auto matcher      = match::name("sum")(first_is_sum, match::standard_shape());
    auto found        = find_match(mod, matcher);
    EXPECT(bool{found.result == mod.end()});
}

Paul's avatar
Paul committed
104
// The sum's second argument is a literal, so arg(1)(name("@literal")) matches.
TEST_CASE(match_arg3)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto second_is_literal = match::arg(1)(match::name("@literal"));
    auto matcher           = match::name("sum")(second_is_literal, match::standard_shape());
    auto found             = find_match(mod, matcher);
    EXPECT(bool{found.result == sum_ins});
}

Paul's avatar
Paul committed
118
// The pass instruction's first argument is the sum, so the pass should match.
TEST_CASE(match_arg4)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1     = mod.add_literal(1);
    auto lit2     = mod.add_literal(2);
    auto sum_ins  = mod.add_instruction(sum_op{}, lit1, lit2);
    auto pass_ins = mod.add_instruction(pass_op{}, sum_ins);

    auto first_is_sum = match::arg(0)(match::name("sum"));
    auto matcher      = match::name("pass")(first_is_sum, match::standard_shape());
    auto found        = find_match(mod, matcher);
    EXPECT(bool{found.result == pass_ins});
}

Paul's avatar
Paul committed
132
// The pass instruction has only one input, so arg(1) cannot match it.
TEST_CASE(match_arg5)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto second_is_sum = match::arg(1)(match::name("sum"));
    auto matcher       = match::name("pass")(second_is_sum, match::standard_shape());
    auto found         = find_match(mod, matcher);
    EXPECT(bool{found.result == mod.end()});
}

Paul's avatar
Paul committed
146
// An argument constraint alone (without standard_shape) should also match the sum.
TEST_CASE(match_arg6)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto first_is_literal = match::arg(0)(match::name("@literal"));
    auto matcher          = match::name("sum")(first_is_literal);
    auto found            = find_match(mod, matcher);
    EXPECT(bool{found.result == sum_ins});
}

Paul's avatar
Paul committed
160
// Constraints on both arguments, passed side by side, should match the sum.
TEST_CASE(match_arg7)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto first_is_literal  = match::arg(0)(match::name("@literal"));
    auto second_is_literal = match::arg(1)(match::name("@literal"));
    auto matcher           = match::name("sum")(first_is_literal, second_is_literal);
    auto found             = find_match(mod, matcher);
    EXPECT(bool{found.result == sum_ins});
}

Paul's avatar
Paul committed
175
176
177
// Both argument constraints grouped with all_of, plus standard_shape, match the sum.
TEST_CASE(match_arg8)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto both_literal = match::all_of(match::arg(0)(match::name("@literal")),
                                      match::arg(1)(match::name("@literal")));
    auto matcher      = match::name("sum")(both_literal, match::standard_shape());
    auto found        = find_match(mod, matcher);
    EXPECT(bool{found.result == sum_ins});
}

Paul's avatar
Paul committed
191
192
193
// The sum has two inputs, so nargs(2) should match it.
TEST_CASE(match_nargs1)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto matcher = match::name("sum")(match::nargs(2));
    auto found   = find_match(mod, matcher);
    EXPECT(bool{found.result == sum_ins});
}

// nargs(2) combined with standard_shape() should still match the sum.
TEST_CASE(match_nargs2)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto matcher = match::name("sum")(match::nargs(2), match::standard_shape());
    auto found   = find_match(mod, matcher);
    EXPECT(bool{found.result == sum_ins});
}

// nargs(2) wrapped in a single-element all_of should still match the sum.
TEST_CASE(match_nargs3)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto two_inputs = match::all_of(match::nargs(2));
    auto matcher    = match::name("sum")(two_inputs);
    auto found      = find_match(mod, matcher);
    EXPECT(bool{found.result == sum_ins});
}

Paul's avatar
Paul committed
233
// args(literal, literal) describes the sum's inputs exactly, so it matches.
TEST_CASE(match_args1)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto inputs  = match::args(match::name("@literal"), match::name("@literal"));
    auto matcher = match::name("sum")(inputs, match::standard_shape());
    auto found   = find_match(mod, matcher);
    EXPECT(bool{found.result == sum_ins});
}

Paul's avatar
Paul committed
248
// The sum's second input is a literal, not a sum, so args(literal, sum) fails.
TEST_CASE(match_args2)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto inputs  = match::args(match::name("@literal"), match::name("sum"));
    auto matcher = match::name("sum")(inputs, match::standard_shape());
    auto found   = find_match(mod, matcher);
    EXPECT(bool{found.result == mod.end()});
}

Paul's avatar
Paul committed
263
// args(...) with a single matcher does not fit the two-input sum, so nothing matches.
TEST_CASE(match_args3)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto inputs  = match::args(match::name("@literal"));
    auto matcher = match::name("sum")(inputs, match::standard_shape());
    auto found   = find_match(mod, matcher);
    EXPECT(bool{found.result == mod.end()});
}

Paul's avatar
Paul committed
277
// Only the outer sum has inputs (sum, literal), so it is the one matched.
TEST_CASE(match_args4)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1  = mod.add_literal(1);
    auto lit2  = mod.add_literal(2);
    auto inner = mod.add_instruction(sum_op{}, lit1, lit2);
    auto outer = mod.add_instruction(sum_op{}, inner, lit2);
    mod.add_instruction(pass_op{}, outer);

    auto inputs  = match::args(match::name("sum"), match::name("@literal"));
    auto matcher = match::name("sum")(inputs, match::standard_shape());
    auto found   = find_match(mod, matcher);
    EXPECT(bool{found.result == outer});
}

Paul's avatar
Paul committed
293
// With a single sum over two literals, args(sum, literal) has nothing to match.
TEST_CASE(match_args5)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto inputs  = match::args(match::name("sum"), match::name("@literal"));
    auto matcher = match::name("sum")(inputs, match::standard_shape());
    auto found   = find_match(mod, matcher);
    EXPECT(bool{found.result == mod.end()});
}

Paul's avatar
Paul committed
308
// The pass instruction's only input is the sum, so args(name("sum")) matches it.
TEST_CASE(match_args6)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1     = mod.add_literal(1);
    auto lit2     = mod.add_literal(2);
    auto sum_ins  = mod.add_instruction(sum_op{}, lit1, lit2);
    auto pass_ins = mod.add_instruction(pass_op{}, sum_ins);

    auto inputs  = match::args(match::name("sum"));
    auto matcher = match::name("pass")(inputs, match::standard_shape());
    auto found   = find_match(mod, matcher);
    EXPECT(bool{found.result == pass_ins});
}

Paul's avatar
Paul committed
322
// Nested args(): pass(sum(literal, literal)) should match the pass instruction.
TEST_CASE(match_args7)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1     = mod.add_literal(1);
    auto lit2     = mod.add_literal(2);
    auto sum_ins  = mod.add_instruction(sum_op{}, lit1, lit2);
    auto pass_ins = mod.add_instruction(pass_op{}, sum_ins);

    auto sum_of_literals =
        match::name("sum")(match::args(match::name("@literal"), match::name("@literal")));
    auto matcher = match::name("pass")(match::args(sum_of_literals), match::standard_shape());
    auto found   = find_match(mod, matcher);
    EXPECT(bool{found.result == pass_ins});
}

Paul's avatar
Paul committed
338
// either_arg(0, 1)(sum, literal) should match the outer sum (inputs: sum, literal).
TEST_CASE(match_either_args1)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1  = mod.add_literal(1);
    auto lit2  = mod.add_literal(2);
    auto inner = mod.add_instruction(sum_op{}, lit1, lit2);
    auto outer = mod.add_instruction(sum_op{}, inner, lit2);
    mod.add_instruction(pass_op{}, outer);

    auto either  = match::either_arg(0, 1)(match::name("sum"), match::name("@literal"));
    auto matcher = match::name("sum")(either);
    auto found   = find_match(mod, matcher);
    EXPECT(bool{found.result == outer});
}

Paul's avatar
Paul committed
354
// Swapping the either_arg matchers to (literal, sum) still matches the outer sum.
TEST_CASE(match_either_args2)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1  = mod.add_literal(1);
    auto lit2  = mod.add_literal(2);
    auto inner = mod.add_instruction(sum_op{}, lit1, lit2);
    auto outer = mod.add_instruction(sum_op{}, inner, lit2);
    mod.add_instruction(pass_op{}, outer);

    auto either  = match::either_arg(0, 1)(match::name("@literal"), match::name("sum"));
    auto matcher = match::name("sum")(either);
    auto found   = find_match(mod, matcher);
    EXPECT(bool{found.result == outer});
}

Paul's avatar
Paul committed
370
// No sum takes a pass instruction as input, so either_arg(pass, literal) never matches.
TEST_CASE(match_either_args3)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1  = mod.add_literal(1);
    auto lit2  = mod.add_literal(2);
    auto inner = mod.add_instruction(sum_op{}, lit1, lit2);
    auto outer = mod.add_instruction(sum_op{}, inner, lit2);
    mod.add_instruction(pass_op{}, outer);

    auto either  = match::either_arg(0, 1)(match::name("pass"), match::name("@literal"));
    auto matcher = match::name("sum")(either);
    auto found   = find_match(mod, matcher);
    EXPECT(bool{found.result == mod.end()});
}

Paul's avatar
Paul committed
386
387
388
// With two any() matchers, the first sum matches and "x"/"y" bind distinct inputs.
TEST_CASE(match_either_args_any1)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1  = mod.add_literal(1);
    auto lit2  = mod.add_literal(2);
    auto inner = mod.add_instruction(sum_op{}, lit1, lit2);
    auto outer = mod.add_instruction(sum_op{}, inner, lit2);
    mod.add_instruction(pass_op{}, outer);

    auto either  = match::either_arg(0, 1)(match::any().bind("x"), match::any().bind("y"));
    auto matcher = match::name("sum")(either);
    auto found   = find_match(mod, matcher);
    EXPECT(bool{found.result == inner});
    EXPECT(bool{found.instructions.at("x") != found.instructions.at("y")});
}

// (any, literal) matches the first sum; the two bindings must be distinct inputs.
TEST_CASE(match_either_args_any2)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1  = mod.add_literal(1);
    auto lit2  = mod.add_literal(2);
    auto inner = mod.add_instruction(sum_op{}, lit1, lit2);
    auto outer = mod.add_instruction(sum_op{}, inner, lit2);
    mod.add_instruction(pass_op{}, outer);

    auto either =
        match::either_arg(0, 1)(match::any().bind("x"), match::name("@literal").bind("y"));
    auto matcher = match::name("sum")(either);
    auto found   = find_match(mod, matcher);
    EXPECT(bool{found.result == inner});
    EXPECT(bool{found.instructions.at("x") != found.instructions.at("y")});
}

// (literal, any) matches the first sum; the two bindings must be distinct inputs.
TEST_CASE(match_either_args_any3)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1  = mod.add_literal(1);
    auto lit2  = mod.add_literal(2);
    auto inner = mod.add_instruction(sum_op{}, lit1, lit2);
    auto outer = mod.add_instruction(sum_op{}, inner, lit2);
    mod.add_instruction(pass_op{}, outer);

    auto either =
        match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y"));
    auto matcher = match::name("sum")(either);
    auto found   = find_match(mod, matcher);
    EXPECT(bool{found.result == inner});
    EXPECT(bool{found.instructions.at("x") != found.instructions.at("y")});
}

// (sum, any) only fits the outer sum; the two bindings must be distinct inputs.
TEST_CASE(match_either_args_any4)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1  = mod.add_literal(1);
    auto lit2  = mod.add_literal(2);
    auto inner = mod.add_instruction(sum_op{}, lit1, lit2);
    auto outer = mod.add_instruction(sum_op{}, inner, lit2);
    mod.add_instruction(pass_op{}, outer);

    auto either =
        match::either_arg(0, 1)(match::name("sum").bind("x"), match::any().bind("y"));
    auto matcher = match::name("sum")(either);
    auto found   = find_match(mod, matcher);
    EXPECT(bool{found.result == outer});
    EXPECT(bool{found.instructions.at("x") != found.instructions.at("y")});
}

// (any, sum) only fits the outer sum; the two bindings must be distinct inputs.
TEST_CASE(match_either_args_any5)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1  = mod.add_literal(1);
    auto lit2  = mod.add_literal(2);
    auto inner = mod.add_instruction(sum_op{}, lit1, lit2);
    auto outer = mod.add_instruction(sum_op{}, inner, lit2);
    mod.add_instruction(pass_op{}, outer);

    auto either =
        match::either_arg(0, 1)(match::any().bind("x"), match::name("sum").bind("y"));
    auto matcher = match::name("sum")(either);
    auto found   = find_match(mod, matcher);
    EXPECT(bool{found.result == outer});
    EXPECT(bool{found.instructions.at("x") != found.instructions.at("y")});
}

Paul's avatar
Paul committed
471
// all_of over two satisfied argument constraints should match the sum.
TEST_CASE(match_all_of1)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto both_literal = match::all_of(match::arg(0)(match::name("@literal")),
                                      match::arg(1)(match::name("@literal")));
    auto matcher      = match::name("sum")(both_literal);
    auto found        = find_match(mod, matcher);
    EXPECT(bool{found.result == sum_ins});
}

Paul's avatar
Paul committed
486
// all_of fails because arg(0) is not a sum, so nothing matches.
TEST_CASE(match_all_of2)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto constraints = match::all_of(match::arg(0)(match::name("sum")),
                                     match::arg(1)(match::name("@literal")));
    auto matcher     = match::name("sum")(constraints);
    auto found       = find_match(mod, matcher);
    EXPECT(bool{found.result == mod.end()});
}

Paul's avatar
Paul committed
501
502
503
// Nested all_of(all_of(...)) behaves like the inner all_of and matches the sum.
TEST_CASE(match_all_of3)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto inner_all = match::all_of(match::arg(0)(match::name("@literal")),
                                   match::arg(1)(match::name("@literal")));
    auto matcher   = match::name("sum")(match::all_of(inner_all));
    auto found     = find_match(mod, matcher);
    EXPECT(bool{found.result == sum_ins});
}

Paul's avatar
Paul committed
516
517
518
// any_of stops at the first success, so throws() is never evaluated and the
// first instruction (the literal) is returned.
TEST_CASE(match_lazy_any_of)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1 = mod.add_literal(1);
    mod.add_instruction(pass_op{}, lit1);

    auto matcher = match::any_of(match::any(), throws());
    auto found   = find_match(mod, matcher);
    EXPECT(bool{found.result == lit1});
}

// all_of stops at the first failure (none()), so throws() is never evaluated.
TEST_CASE(match_lazy_all_of)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1 = mod.add_literal(1);
    mod.add_instruction(pass_op{}, lit1);

    auto matcher = match::all_of(match::none(), throws());
    auto found   = find_match(mod, matcher);
    EXPECT(bool{found.result == mod.end()});
}

// none_of fails as soon as any() succeeds, so throws() is never evaluated.
TEST_CASE(match_lazy_none_of)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1 = mod.add_literal(1);
    mod.add_instruction(pass_op{}, lit1);

    auto matcher = match::none_of(match::any(), throws());
    auto found   = find_match(mod, matcher);
    EXPECT(bool{found.result == mod.end()});
}

Paul's avatar
Paul committed
552
// any_of succeeds through its second alternative (arg(1) is a literal).
TEST_CASE(match_any_of1)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto alternatives = match::any_of(match::arg(0)(match::name("sum")),
                                      match::arg(1)(match::name("@literal")));
    auto matcher      = match::name("sum")(alternatives);
    auto found        = find_match(mod, matcher);
    EXPECT(bool{found.result == sum_ins});
}

Paul's avatar
Paul committed
567
// Neither argument of the sum is itself a sum, so any_of fails.
TEST_CASE(match_any_of2)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto alternatives = match::any_of(match::arg(0)(match::name("sum")),
                                      match::arg(1)(match::name("sum")));
    auto matcher      = match::name("sum")(alternatives);
    auto found        = find_match(mod, matcher);
    EXPECT(bool{found.result == mod.end()});
}

Paul's avatar
Paul committed
582
583
584
// The first any_of alternative succeeds, so only its binding "x" is recorded.
TEST_CASE(match_any_of_lazy1)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto first_alt  = match::args(match::any(), match::any()).bind("x");
    auto second_alt = match::args(match::name("sum"), match::name("sum")).bind("y");
    auto matcher    = match::name("sum")(match::any_of(first_alt, second_alt));
    auto found      = find_match(mod, matcher);
    EXPECT(bool{found.result == sum_ins});
    EXPECT(migraphx::contains(found.instructions, "x"));
    EXPECT(bool{found.instructions["x"] == sum_ins});
    EXPECT(not migraphx::contains(found.instructions, "y"));
}

// The first alternative (two literals) already succeeds, so "y" is never bound.
TEST_CASE(match_any_of_lazy2)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto first_alt  = match::args(match::name("@literal"), match::name("@literal")).bind("x");
    auto second_alt = match::args(match::any(), match::any()).bind("y");
    auto matcher    = match::name("sum")(match::any_of(first_alt, second_alt));
    auto found      = find_match(mod, matcher);
    EXPECT(bool{found.result == sum_ins});
    EXPECT(migraphx::contains(found.instructions, "x"));
    EXPECT(bool{found.instructions["x"] == sum_ins});
    EXPECT(not migraphx::contains(found.instructions, "y"));
}

// Both alternatives could match, but only the first is evaluated, so only
// "x" is bound.
TEST_CASE(match_any_of_lazy3)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto first_alt  = match::args(match::any(), match::any()).bind("x");
    auto second_alt = match::args(match::name("@literal"), match::name("@literal")).bind("y");
    auto matcher    = match::name("sum")(match::any_of(first_alt, second_alt));
    auto found      = find_match(mod, matcher);
    EXPECT(bool{found.result == sum_ins});
    EXPECT(migraphx::contains(found.instructions, "x"));
    EXPECT(bool{found.instructions["x"] == sum_ins});
    EXPECT(not migraphx::contains(found.instructions, "y"));
}

// Argument-level bindings: only the first alternative's "x1"/"y1" are
// recorded, bound to the two literals in order.
TEST_CASE(match_any_of_lazy4)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto first_alt =
        match::args(match::name("@literal").bind("x1"), match::name("@literal").bind("y1"));
    auto second_alt = match::args(match::any().bind("x2"), match::any().bind("y2"));
    auto matcher    = match::name("sum")(match::any_of(first_alt, second_alt));
    auto found      = find_match(mod, matcher);
    EXPECT(bool{found.result == sum_ins});
    EXPECT(migraphx::contains(found.instructions, "x1"));
    EXPECT(migraphx::contains(found.instructions, "y1"));
    EXPECT(bool{found.instructions["x1"] == lit1});
    EXPECT(bool{found.instructions["y1"] == lit2});
    EXPECT(not migraphx::contains(found.instructions, "x2"));
    EXPECT(not migraphx::contains(found.instructions, "y2"));
}

// Mirror of match_any_of_lazy4 with the alternatives swapped: the any()
// alternative is first, so "x1"/"y1" are bound and "x2"/"y2" are not.
TEST_CASE(match_any_of_lazy5)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto first_alt = match::args(match::any().bind("x1"), match::any().bind("y1"));
    auto second_alt =
        match::args(match::name("@literal").bind("x2"), match::name("@literal").bind("y2"));
    auto matcher = match::name("sum")(match::any_of(first_alt, second_alt));
    auto found   = find_match(mod, matcher);
    EXPECT(bool{found.result == sum_ins});
    EXPECT(migraphx::contains(found.instructions, "x1"));
    EXPECT(migraphx::contains(found.instructions, "y1"));
    EXPECT(bool{found.instructions["x1"] == lit1});
    EXPECT(bool{found.instructions["y1"] == lit2});
    EXPECT(not migraphx::contains(found.instructions, "x2"));
    EXPECT(not migraphx::contains(found.instructions, "y2"));
}

Paul's avatar
Paul committed
683
// Neither argument is a sum, so none_of holds and the sum matches.
TEST_CASE(match_none_of1)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto neither_sum = match::none_of(match::arg(0)(match::name("sum")),
                                      match::arg(1)(match::name("sum")));
    auto matcher     = match::name("sum")(neither_sum);
    auto found       = find_match(mod, matcher);
    EXPECT(bool{found.result == sum_ins});
}

Paul's avatar
Paul committed
698
// Both arguments are literals, so none_of fails and nothing matches.
TEST_CASE(match_none_of2)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto neither_literal = match::none_of(match::arg(0)(match::name("@literal")),
                                          match::arg(1)(match::name("@literal")));
    auto matcher         = match::name("sum")(neither_literal);
    auto found           = find_match(mod, matcher);
    EXPECT(bool{found.result == mod.end()});
}

Paul's avatar
Paul committed
713
714
715
// The minus instruction feeds the sum, so output(name("sum")) matches it.
TEST_CASE(match_output1)
{
    migraphx::program prog;
    auto& mod = *prog.get_main_module();

    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto diff    = mod.add_instruction(minus_op{}, lit2, lit1);
    auto sum_ins = mod.add_instruction(sum_op{}, diff, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto matcher = match::name("minus")(match::output(match::name("sum")));
    auto found   = find_match(mod, matcher);
    EXPECT(bool{found.result == diff});
}

// `one` feeds only `minus`. `two` feeds both `minus` and `sum`, and the test
// expects it not to satisfy output(name("sum")), so nothing matches.
TEST_CASE(match_output2)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto minus = mm->add_instruction(minus_op{}, two, one);
    auto sum   = mm->add_instruction(sum_op{}, minus, two);
    mm->add_instruction(pass_op{}, sum);
    auto m = match::name("@literal")(match::output(match::name("sum")));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == mm->end()});
}

// skip_output(pass) still matches when there is no "pass" between `minus` and
// its consumer `sum`.
TEST_CASE(match_skip_output1)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto minus = mm->add_instruction(minus_op{}, two, one);
    auto sum   = mm->add_instruction(sum_op{}, minus, two);
    mm->add_instruction(pass_op{}, sum);
    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == minus});
}

// A single "pass" between `minus` and `sum` is skipped over.
TEST_CASE(match_skip_output2)
{
    migraphx::program p;

    auto* mm        = p.get_main_module();
    auto one        = mm->add_literal(1);
    auto two        = mm->add_literal(2);
    auto minus      = mm->add_instruction(minus_op{}, two, one);
    auto minus_pass = mm->add_instruction(pass_op{}, minus);
    auto sum        = mm->add_instruction(sum_op{}, minus_pass, two);
    mm->add_instruction(pass_op{}, sum);
    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == minus});
}

// A chain of three "pass" instructions between `minus` and `sum` is skipped over.
TEST_CASE(match_skip_output3)
{
    migraphx::program p;

    auto* mm         = p.get_main_module();
    auto one         = mm->add_literal(1);
    auto two         = mm->add_literal(2);
    auto minus       = mm->add_instruction(minus_op{}, two, one);
    auto minus_pass1 = mm->add_instruction(pass_op{}, minus);
    auto minus_pass2 = mm->add_instruction(pass_op{}, minus_pass1);
    auto minus_pass3 = mm->add_instruction(pass_op{}, minus_pass2);
    auto sum         = mm->add_instruction(sum_op{}, minus_pass3, two);
    mm->add_instruction(pass_op{}, sum);
    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == minus});
}

// `one` reaches `sum` through a "pass"; `two` feeds `sum` directly. The test
// expects `two` as the reported match.
// NOTE(review): this relies on the order in which find_match visits the literals
// (presumably the literal insertion order) — confirm.
TEST_CASE(match_skip_output4)
{
    migraphx::program p;

    auto* mm  = p.get_main_module();
    auto one  = mm->add_literal(1);
    auto two  = mm->add_literal(2);
    auto pass = mm->add_instruction(pass_op{}, one);
    auto sum  = mm->add_instruction(sum_op{}, pass, two);
    mm->add_instruction(pass_op{}, sum);
    auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == two});
}

// Each literal has more than one consumer here (`one`: pass and sum2; `two`:
// sum1 and sum3). The test expects no match.
TEST_CASE(match_skip_output5)
{
    migraphx::program p;

    auto* mm  = p.get_main_module();
    auto one  = mm->add_literal(1);
    auto two  = mm->add_literal(2);
    auto pass = mm->add_instruction(pass_op{}, one);
    auto sum1 = mm->add_instruction(sum_op{}, pass, two);
    auto sum2 = mm->add_instruction(sum_op{}, sum1, one);
    auto sum3 = mm->add_instruction(sum_op{}, sum2, two);
    mm->add_instruction(pass_op{}, sum3);
    auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == mm->end()});
}

// `minus` is consumed only by `sum1`, so it matches even though a chain of
// further sums follows.
TEST_CASE(match_skip_output6)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto minus = mm->add_instruction(minus_op{}, two, one);
    auto sum1  = mm->add_instruction(sum_op{}, minus, two);
    auto sum2  = mm->add_instruction(sum_op{}, sum1, one);
    auto sum3  = mm->add_instruction(sum_op{}, sum2, two);
    mm->add_instruction(pass_op{}, sum3);
    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == minus});
}

// The target matcher after skip_output can be any matcher: `minus1` feeds
// `minus2`, so a "minus" whose consumer is a "minus" is found.
TEST_CASE(match_skip_output7)
{
    migraphx::program p;

    auto* mm    = p.get_main_module();
    auto one    = mm->add_literal(1);
    auto two    = mm->add_literal(2);
    auto minus1 = mm->add_instruction(minus_op{}, two, one);
    auto minus2 = mm->add_instruction(minus_op{}, two, minus1);
    auto sum    = mm->add_instruction(sum_op{}, one, minus2);
    mm->add_instruction(pass_op{}, sum);
    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("minus")));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == minus1});
}

// bind() records named sub-matches in r.instructions. Every bound name must map
// to the instruction it matched, and the top-level result is the "pass".
TEST_CASE(match_bind1)
{
    migraphx::program p;

    auto* mm  = p.get_main_module();
    auto one  = mm->add_literal(1);
    auto two  = mm->add_literal(2);
    auto sum  = mm->add_instruction(sum_op{}, one, two);
    auto pass = mm->add_instruction(pass_op{}, sum);
    auto m    = match::name("pass")(
                 match::args(match::name("sum")(match::args(match::name("@literal").bind("one"),
                                                            match::name("@literal").bind("two")))
                                 .bind("sum")),
                 match::standard_shape())
                 .bind("pass");
    auto r = find_match(*mm, m);
    EXPECT(bool{r.instructions.at("one") == one});
    EXPECT(bool{r.instructions.at("two") == two});
    EXPECT(bool{r.instructions.at("sum") == sum});
    EXPECT(bool{r.instructions.at("pass") == pass});
    EXPECT(bool{r.result == pass});
}

// has_value(1) selects the literal created with value 1.
TEST_CASE(match_has_value1)
{
    migraphx::program p;

    auto* mm  = p.get_main_module();
    auto one  = mm->add_literal(1);
    auto two  = mm->add_literal(2);
    auto sum1 = mm->add_instruction(sum_op{}, one, two);
    auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
    mm->add_instruction(pass_op{}, sum2);
    auto m = match::has_value(1);
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == one});
}

// has_value(2) selects the literal created with value 2.
TEST_CASE(match_has_value2)
{
    migraphx::program p;

    auto* mm  = p.get_main_module();
    auto one  = mm->add_literal(1);
    auto two  = mm->add_literal(2);
    auto sum1 = mm->add_instruction(sum_op{}, one, two);
    auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
    mm->add_instruction(pass_op{}, sum2);
    auto m = match::has_value(2);
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == two});
}

// args(has_value(1), has_value(2)) matches sum1(one, two) by literal value.
TEST_CASE(match_has_value3)
{
    migraphx::program p;

    auto* mm  = p.get_main_module();
    auto one  = mm->add_literal(1);
    auto two  = mm->add_literal(2);
    auto sum1 = mm->add_instruction(sum_op{}, one, two);
    auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
    mm->add_instruction(pass_op{}, sum2);
    auto m = match::name("sum")(match::args(match::has_value(1), match::has_value(2)));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == sum1});
}

// No literal holds the value 3, so nothing matches.
TEST_CASE(match_has_value4)
{
    migraphx::program p;

    auto* mm  = p.get_main_module();
    auto one  = mm->add_literal(1);
    auto two  = mm->add_literal(2);
    auto sum1 = mm->add_instruction(sum_op{}, one, two);
    auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
    mm->add_instruction(pass_op{}, sum2);
    auto m = match::has_value(3);
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == mm->end()});
}

// No sum has arguments with values (1, 3), so nothing matches.
TEST_CASE(match_has_value5)
{
    migraphx::program p;

    auto* mm  = p.get_main_module();
    auto one  = mm->add_literal(1);
    auto two  = mm->add_literal(2);
    auto sum1 = mm->add_instruction(sum_op{}, one, two);
    auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
    mm->add_instruction(pass_op{}, sum2);
    auto m = match::name("sum")(match::args(match::has_value(1), match::has_value(3)));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == mm->end()});
}

// args() is positional: (2, 1) does not match sum1(one, two), so nothing matches.
TEST_CASE(match_has_value6)
{
    migraphx::program p;

    auto* mm  = p.get_main_module();
    auto one  = mm->add_literal(1);
    auto two  = mm->add_literal(2);
    auto sum1 = mm->add_instruction(sum_op{}, one, two);
    auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
    mm->add_instruction(pass_op{}, sum2);
    auto m = match::name("sum")(match::args(match::has_value(2), match::has_value(1)));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == mm->end()});
}

// tree(sum, 1, 2, 3) matches the nested sum2(sum1(one, two), three).
TEST_CASE(match_tree1)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto three = mm->add_literal(3);
    auto sum1  = mm->add_instruction(sum_op{}, one, two);
    auto sum2  = mm->add_instruction(sum_op{}, sum1, three);
    mm->add_instruction(pass_op{}, sum2);
    auto m = match::tree(
        match::name("sum"), match::has_value(1), match::has_value(2), match::has_value(3));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == sum2});
}

// tree() is order-sensitive: leaves (2, 1, 3) do not match sum2(sum1(one, two), three).
TEST_CASE(match_tree2)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto three = mm->add_literal(3);
    auto sum1  = mm->add_instruction(sum_op{}, one, two);
    auto sum2  = mm->add_instruction(sum_op{}, sum1, three);
    mm->add_instruction(pass_op{}, sum2);
    auto m = match::tree(
        match::name("sum"), match::has_value(2), match::has_value(1), match::has_value(3));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == mm->end()});
}

// The nested sum may appear as the second argument: sum2(three, sum1(one, two))
// matches tree(sum, 3, 1, 2).
TEST_CASE(match_tree3)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto three = mm->add_literal(3);
    auto sum1  = mm->add_instruction(sum_op{}, one, two);
    auto sum2  = mm->add_instruction(sum_op{}, three, sum1);
    mm->add_instruction(pass_op{}, sum2);
    auto m = match::tree(
        match::name("sum"), match::has_value(3), match::has_value(1), match::has_value(2));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == sum2});
}

// An extra leaf (4) that has no counterpart in the graph prevents a match.
TEST_CASE(match_tree4)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto three = mm->add_literal(3);
    auto sum1  = mm->add_instruction(sum_op{}, one, two);
    auto sum2  = mm->add_instruction(sum_op{}, sum1, three);
    mm->add_instruction(pass_op{}, sum2);
    auto m = match::tree(match::name("sum"),
                         match::has_value(1),
                         match::has_value(2),
                         match::has_value(3),
                         match::has_value(4));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == mm->end()});
}

// Leaves (2, 3) do not describe any sum in the graph; the test expects no match.
TEST_CASE(match_tree5)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto three = mm->add_literal(3);
    auto sum1  = mm->add_instruction(sum_op{}, one, two);
    auto sum2  = mm->add_instruction(sum_op{}, sum1, three);
    mm->add_instruction(pass_op{}, sum2);
    auto m = match::tree(match::name("sum"), match::has_value(2), match::has_value(3));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == mm->end()});
}

// Leaves (1, 3) do not describe any sum in the graph; the test expects no match.
TEST_CASE(match_tree6)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto three = mm->add_literal(3);
    auto sum1  = mm->add_instruction(sum_op{}, one, two);
    auto sum2  = mm->add_instruction(sum_op{}, sum1, three);
    mm->add_instruction(pass_op{}, sum2);
    auto m = match::tree(match::name("sum"), match::has_value(1), match::has_value(3));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == mm->end()});
}

// unordered_tree ignores argument order: leaves (3, 2, 1) match
// sum2(sum1(one, two), three).
TEST_CASE(match_unordered_tree1)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto three = mm->add_literal(3);
    auto sum1  = mm->add_instruction(sum_op{}, one, two);
    auto sum2  = mm->add_instruction(sum_op{}, sum1, three);
    mm->add_instruction(pass_op{}, sum2);
    auto m = match::unordered_tree(
        match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == sum2});
}

// unordered_tree also matches when the nested sum is the second argument:
// sum2(three, sum1(one, two)).
TEST_CASE(match_unordered_tree2)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto three = mm->add_literal(3);
    auto sum1  = mm->add_instruction(sum_op{}, one, two);
    auto sum2  = mm->add_instruction(sum_op{}, three, sum1);
    mm->add_instruction(pass_op{}, sum2);
    auto m = match::unordered_tree(
        match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == sum2});
}

// unordered_tree also matches when the inner sum's arguments are swapped:
// sum2(sum1(two, one), three).
TEST_CASE(match_unordered_tree3)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto three = mm->add_literal(3);
    auto sum1  = mm->add_instruction(sum_op{}, two, one);
    auto sum2  = mm->add_instruction(sum_op{}, sum1, three);
    mm->add_instruction(pass_op{}, sum2);
    auto m = match::unordered_tree(
        match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == sum2});
}

// No literal holds 4, so unordered_tree(sum, 4, 2, 1) finds no match.
TEST_CASE(match_unordered_tree4)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto three = mm->add_literal(3);
    auto sum1  = mm->add_instruction(sum_op{}, one, two);
    auto sum2  = mm->add_instruction(sum_op{}, sum1, three);
    mm->add_instruction(pass_op{}, sum2);
    auto m = match::unordered_tree(
        match::name("sum"), match::has_value(4), match::has_value(2), match::has_value(1));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == mm->end()});
}

struct match_find_sum
{
Paul's avatar
Paul committed
1143
    migraphx::instruction_ref ins;
Paul's avatar
Paul committed
1144
    auto matcher() const { return match::name("sum"); }
Paul's avatar
Paul committed
1145

1146
    void apply(migraphx::module&, const match::matcher_result& r) const
Paul's avatar
Paul committed
1147
1148
1149
    {
        EXPECT(bool{r.result == ins});
    }
Paul's avatar
Paul committed
1150
1151
1152
1153
};

struct match_find_literal
{
Paul's avatar
Paul committed
1154
    migraphx::instruction_ref ins;
Paul's avatar
Paul committed
1155
    auto matcher() const { return match::name("@literal"); }
Paul's avatar
Paul committed
1156

1157
    void apply(migraphx::module&, const match::matcher_result& r) const
Paul's avatar
Paul committed
1158
1159
1160
1161
1162
1163
    {
        EXPECT(bool{r.result != ins});
        EXPECT(r.result->name() == "@literal");
    }
};

Paul's avatar
Paul committed
1164
// Runs both finders over the module in one find_matches call. The assertions
// live in each finder's apply().
TEST_CASE(match_finder)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto one = mm->add_literal(1);
    auto two = mm->add_literal(2);
    auto sum = mm->add_instruction(sum_op{}, one, two);
    mm->add_instruction(pass_op{}, sum);
    match::find_matches(*mm, match_find_sum{sum}, match_find_literal{sum});
}

// Test entry point: hands the command-line arguments straight to the test runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}