matcher.cpp 36.7 KB
Newer Older
Paul's avatar
Paul committed
1
2
#include <migraphx/matcher.hpp>
#include <migraphx/iterator_for.hpp>
Paul's avatar
Paul committed
3
4
5
#include <test.hpp>
#include <basic_ops.hpp>

Paul's avatar
Paul committed
6
namespace match = migraphx::match;
Paul's avatar
Paul committed
7

Paul's avatar
Paul committed
8
9
// Predicate matcher that unconditionally throws when evaluated. The lazy
// any_of/all_of/none_of tests put it after a matcher that already decides the
// outcome, so evaluating it at all would raise and break those tests.
MIGRAPHX_PRED_MATCHER(throws, migraphx::instruction_ref)
{
    MIGRAPHX_THROW("Matcher throws");
}

Paul's avatar
Paul committed
10
template <class M>
11
migraphx::match::matcher_result find_match(migraphx::module& modl, M&& m)
Paul's avatar
Paul committed
12
{
Paul's avatar
Paul committed
13
    migraphx::match::matcher_result result;
14
    for(auto ins : migraphx::iterator_for(modl))
Paul's avatar
Paul committed
15
    {
16
17
        result = migraphx::match::match_instruction(modl, ins, m);
        if(result.result != modl.end())
Paul's avatar
Paul committed
18
19
20
21
22
23
24
            return result;
    }
    return result;
}

void match1()
{
Paul's avatar
Paul committed
25
    migraphx::program p;
26
27
28
29
30

    auto* mm = p.get_main_module();
    auto l   = mm->add_literal(1);
    auto m   = match::standard_shape();
    auto r   = find_match(*mm, m);
Paul's avatar
Paul committed
31
32
33
    EXPECT(bool{r.result == l});
}

Paul's avatar
Paul committed
34
TEST_CASE(match_name1)
{
    // Graph: add = sum(1, 2); pass(add). The name matcher for "sum" is
    // expected to select `add`.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1 = main_mod->add_literal(1);
    auto lit2 = main_mod->add_literal(2);
    auto add  = main_mod->add_instruction(sum_op{}, lit1, lit2);
    main_mod->add_instruction(pass_op{}, add);

    auto pattern = match::name("sum");
    auto found   = find_match(*main_mod, pattern);
    EXPECT(bool{found.result == add});
}

Paul's avatar
Paul committed
48
TEST_CASE(match_name2)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto one = mm->add_literal(1);
    auto two = mm->add_literal(2);
    auto sum = mm->add_instruction(sum_op{}, one, two);
    mm->add_instruction(pass_op{}, sum);
    // No instruction is named "min", so the search must report no match.
    auto m = match::name("min");
    auto r = find_match(*mm, m);
    // find_match signals "no match" with end() of the module it searched, so
    // compare against that module's end() rather than the program's.
    EXPECT(bool{r.result == mm->end()});
}

Paul's avatar
Paul committed
62
TEST_CASE(match_name3)
{
    // Name matcher refined by a standard_shape() sub-matcher; the sum
    // instruction is expected to satisfy both.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1 = main_mod->add_literal(1);
    auto lit2 = main_mod->add_literal(2);
    auto add  = main_mod->add_instruction(sum_op{}, lit1, lit2);
    main_mod->add_instruction(pass_op{}, add);

    auto pattern = match::name("sum")(match::standard_shape());
    auto found   = find_match(*main_mod, pattern);
    EXPECT(bool{found.result == add});
}

Paul's avatar
Paul committed
76
TEST_CASE(match_arg1)
{
    // The first input (index 0) of the sum is a literal, so arg(0) with a
    // "@literal" name matcher should accept the sum.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1 = main_mod->add_literal(1);
    auto lit2 = main_mod->add_literal(2);
    auto add  = main_mod->add_instruction(sum_op{}, lit1, lit2);
    main_mod->add_instruction(pass_op{}, add);

    auto first_is_literal = match::arg(0)(match::name("@literal"));
    auto pattern          = match::name("sum")(first_is_literal, match::standard_shape());
    auto found            = find_match(*main_mod, pattern);
    EXPECT(bool{found.result == add});
}

Paul's avatar
Paul committed
90
TEST_CASE(match_arg2)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto one = mm->add_literal(1);
    auto two = mm->add_literal(2);
    auto sum = mm->add_instruction(sum_op{}, one, two);
    mm->add_instruction(pass_op{}, sum);
    // The sum's first input is a literal, not another sum, so nothing matches.
    auto m = match::name("sum")(match::arg(0)(match::name("sum")), match::standard_shape());
    auto r = find_match(*mm, m);
    // Compare against end() of the searched module, which is what find_match
    // returns when no instruction matches.
    EXPECT(bool{r.result == mm->end()});
}

Paul's avatar
Paul committed
104
TEST_CASE(match_arg3)
{
    // Same as match_arg1 but checks the second input (index 1), which is also
    // a literal.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1 = main_mod->add_literal(1);
    auto lit2 = main_mod->add_literal(2);
    auto add  = main_mod->add_instruction(sum_op{}, lit1, lit2);
    main_mod->add_instruction(pass_op{}, add);

    auto second_is_literal = match::arg(1)(match::name("@literal"));
    auto pattern           = match::name("sum")(second_is_literal, match::standard_shape());
    auto found             = find_match(*main_mod, pattern);
    EXPECT(bool{found.result == add});
}

Paul's avatar
Paul committed
118
TEST_CASE(match_arg4)
{
    // The pass instruction's only input is the sum, so arg(0) == "sum" should
    // make the pass instruction the match.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1        = main_mod->add_literal(1);
    auto lit2        = main_mod->add_literal(2);
    auto add         = main_mod->add_instruction(sum_op{}, lit1, lit2);
    auto passthrough = main_mod->add_instruction(pass_op{}, add);

    auto pattern =
        match::name("pass")(match::arg(0)(match::name("sum")), match::standard_shape());
    auto found = find_match(*main_mod, pattern);
    EXPECT(bool{found.result == passthrough});
}

Paul's avatar
Paul committed
132
TEST_CASE(match_arg5)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto one = mm->add_literal(1);
    auto two = mm->add_literal(2);
    auto sum = mm->add_instruction(sum_op{}, one, two);
    mm->add_instruction(pass_op{}, sum);
    // pass has a single input, so there is no arg(1) to satisfy the matcher.
    auto m = match::name("pass")(match::arg(1)(match::name("sum")), match::standard_shape());
    auto r = find_match(*mm, m);
    // Compare against end() of the searched module, which is what find_match
    // returns when no instruction matches.
    EXPECT(bool{r.result == mm->end()});
}

Paul's avatar
Paul committed
146
TEST_CASE(match_arg6)
{
    // arg(0) literal check without an accompanying shape constraint.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1 = main_mod->add_literal(1);
    auto lit2 = main_mod->add_literal(2);
    auto add  = main_mod->add_instruction(sum_op{}, lit1, lit2);
    main_mod->add_instruction(pass_op{}, add);

    auto first_is_literal = match::arg(0)(match::name("@literal"));
    auto pattern          = match::name("sum")(first_is_literal);
    auto found            = find_match(*main_mod, pattern);
    EXPECT(bool{found.result == add});
}

Paul's avatar
Paul committed
160
TEST_CASE(match_arg7)
{
    // Both inputs of the sum are literals; two independent arg() matchers
    // passed directly to the name matcher must both succeed.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1 = main_mod->add_literal(1);
    auto lit2 = main_mod->add_literal(2);
    auto add  = main_mod->add_instruction(sum_op{}, lit1, lit2);
    main_mod->add_instruction(pass_op{}, add);

    auto first_is_literal  = match::arg(0)(match::name("@literal"));
    auto second_is_literal = match::arg(1)(match::name("@literal"));
    auto pattern           = match::name("sum")(first_is_literal, second_is_literal);
    auto found             = find_match(*main_mod, pattern);
    EXPECT(bool{found.result == add});
}

Paul's avatar
Paul committed
175
176
177
TEST_CASE(match_arg8)
{
    // Like match_arg7, but the two arg() checks are grouped with all_of and
    // combined with a standard_shape() constraint.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1 = main_mod->add_literal(1);
    auto lit2 = main_mod->add_literal(2);
    auto add  = main_mod->add_instruction(sum_op{}, lit1, lit2);
    main_mod->add_instruction(pass_op{}, add);

    auto both_literal = match::all_of(match::arg(0)(match::name("@literal")),
                                      match::arg(1)(match::name("@literal")));
    auto pattern      = match::name("sum")(both_literal, match::standard_shape());
    auto found        = find_match(*main_mod, pattern);
    EXPECT(bool{found.result == add});
}

Paul's avatar
Paul committed
191
192
193
TEST_CASE(match_nargs1)
{
    // The sum was built with two inputs, so nargs(2) should accept it.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1 = main_mod->add_literal(1);
    auto lit2 = main_mod->add_literal(2);
    auto add  = main_mod->add_instruction(sum_op{}, lit1, lit2);
    main_mod->add_instruction(pass_op{}, add);

    auto pattern = match::name("sum")(match::nargs(2));
    auto found   = find_match(*main_mod, pattern);
    EXPECT(bool{found.result == add});
}

TEST_CASE(match_nargs2)
{
    // nargs(2) combined with a standard_shape() constraint.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1 = main_mod->add_literal(1);
    auto lit2 = main_mod->add_literal(2);
    auto add  = main_mod->add_instruction(sum_op{}, lit1, lit2);
    main_mod->add_instruction(pass_op{}, add);

    auto pattern = match::name("sum")(match::nargs(2), match::standard_shape());
    auto found   = find_match(*main_mod, pattern);
    EXPECT(bool{found.result == add});
}

TEST_CASE(match_nargs3)
{
    // nargs(2) wrapped in a single-element all_of.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1 = main_mod->add_literal(1);
    auto lit2 = main_mod->add_literal(2);
    auto add  = main_mod->add_instruction(sum_op{}, lit1, lit2);
    main_mod->add_instruction(pass_op{}, add);

    auto two_inputs = match::all_of(match::nargs(2));
    auto pattern    = match::name("sum")(two_inputs);
    auto found      = find_match(*main_mod, pattern);
    EXPECT(bool{found.result == add});
}

Paul's avatar
Paul committed
233
TEST_CASE(match_args1)
{
    // args() with one matcher per input: both inputs are literals.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1 = main_mod->add_literal(1);
    auto lit2 = main_mod->add_literal(2);
    auto add  = main_mod->add_instruction(sum_op{}, lit1, lit2);
    main_mod->add_instruction(pass_op{}, add);

    auto inputs  = match::args(match::name("@literal"), match::name("@literal"));
    auto pattern = match::name("sum")(inputs, match::standard_shape());
    auto found   = find_match(*main_mod, pattern);
    EXPECT(bool{found.result == add});
}

Paul's avatar
Paul committed
248
TEST_CASE(match_args2)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto one = mm->add_literal(1);
    auto two = mm->add_literal(2);
    auto sum = mm->add_instruction(sum_op{}, one, two);
    mm->add_instruction(pass_op{}, sum);
    // The sum's second input is a literal, not a sum, so args() must fail.
    auto m = match::name("sum")(match::args(match::name("@literal"), match::name("sum")),
                                match::standard_shape());
    auto r = find_match(*mm, m);
    // Compare against end() of the searched module, which is what find_match
    // returns when no instruction matches.
    EXPECT(bool{r.result == mm->end()});
}

Paul's avatar
Paul committed
263
TEST_CASE(match_args3)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto one = mm->add_literal(1);
    auto two = mm->add_literal(2);
    auto sum = mm->add_instruction(sum_op{}, one, two);
    mm->add_instruction(pass_op{}, sum);
    // args() is given one matcher while the sum has two inputs; the test
    // expects this arity mismatch to reject the sum.
    auto m = match::name("sum")(match::args(match::name("@literal")), match::standard_shape());
    auto r = find_match(*mm, m);
    // Compare against end() of the searched module, which is what find_match
    // returns when no instruction matches.
    EXPECT(bool{r.result == mm->end()});
}

Paul's avatar
Paul committed
277
TEST_CASE(match_args4)
{
    // Two chained sums: inner = sum(1, 2), outer = sum(inner, 2). Only the
    // outer sum has a sum as its first input and a literal as its second.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1  = main_mod->add_literal(1);
    auto lit2  = main_mod->add_literal(2);
    auto inner = main_mod->add_instruction(sum_op{}, lit1, lit2);
    auto outer = main_mod->add_instruction(sum_op{}, inner, lit2);
    main_mod->add_instruction(pass_op{}, outer);

    auto inputs  = match::args(match::name("sum"), match::name("@literal"));
    auto pattern = match::name("sum")(inputs, match::standard_shape());
    auto found   = find_match(*main_mod, pattern);
    EXPECT(bool{found.result == outer});
}

Paul's avatar
Paul committed
293
TEST_CASE(match_args5)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto one = mm->add_literal(1);
    auto two = mm->add_literal(2);
    auto sum = mm->add_instruction(sum_op{}, one, two);
    mm->add_instruction(pass_op{}, sum);
    // The only sum has two literal inputs, so args(sum, @literal) must fail.
    auto m = match::name("sum")(match::args(match::name("sum"), match::name("@literal")),
                                match::standard_shape());
    auto r = find_match(*mm, m);
    // Compare against end() of the searched module, which is what find_match
    // returns when no instruction matches.
    EXPECT(bool{r.result == mm->end()});
}

Paul's avatar
Paul committed
308
TEST_CASE(match_args6)
{
    // Single-input args() on the pass instruction, whose only input is the sum.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1        = main_mod->add_literal(1);
    auto lit2        = main_mod->add_literal(2);
    auto add         = main_mod->add_instruction(sum_op{}, lit1, lit2);
    auto passthrough = main_mod->add_instruction(pass_op{}, add);

    auto pattern = match::name("pass")(match::args(match::name("sum")), match::standard_shape());
    auto found   = find_match(*main_mod, pattern);
    EXPECT(bool{found.result == passthrough});
}

Paul's avatar
Paul committed
322
TEST_CASE(match_args7)
{
    // Nested args(): pass(sum(@literal, @literal)) should match the pass
    // instruction.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1        = main_mod->add_literal(1);
    auto lit2        = main_mod->add_literal(2);
    auto add         = main_mod->add_instruction(sum_op{}, lit1, lit2);
    auto passthrough = main_mod->add_instruction(pass_op{}, add);

    auto sum_of_literals =
        match::name("sum")(match::args(match::name("@literal"), match::name("@literal")));
    auto pattern = match::name("pass")(match::args(sum_of_literals), match::standard_shape());
    auto found   = find_match(*main_mod, pattern);
    EXPECT(bool{found.result == passthrough});
}

Paul's avatar
Paul committed
338
TEST_CASE(match_either_args1)
{
    // either_arg(0, 1) with (sum, @literal) in the same order as the outer
    // sum's inputs.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1  = main_mod->add_literal(1);
    auto lit2  = main_mod->add_literal(2);
    auto inner = main_mod->add_instruction(sum_op{}, lit1, lit2);
    auto outer = main_mod->add_instruction(sum_op{}, inner, lit2);
    main_mod->add_instruction(pass_op{}, outer);

    auto either  = match::either_arg(0, 1)(match::name("sum"), match::name("@literal"));
    auto pattern = match::name("sum")(either);
    auto found   = find_match(*main_mod, pattern);
    EXPECT(bool{found.result == outer});
}

Paul's avatar
Paul committed
354
TEST_CASE(match_either_args2)
{
    // Same graph as match_either_args1 with the matchers swapped; either_arg
    // should still accept the outer sum.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1  = main_mod->add_literal(1);
    auto lit2  = main_mod->add_literal(2);
    auto inner = main_mod->add_instruction(sum_op{}, lit1, lit2);
    auto outer = main_mod->add_instruction(sum_op{}, inner, lit2);
    main_mod->add_instruction(pass_op{}, outer);

    auto either  = match::either_arg(0, 1)(match::name("@literal"), match::name("sum"));
    auto pattern = match::name("sum")(either);
    auto found   = find_match(*main_mod, pattern);
    EXPECT(bool{found.result == outer});
}

Paul's avatar
Paul committed
370
TEST_CASE(match_either_args3)
{
    migraphx::program p;

    auto* mm  = p.get_main_module();
    auto one  = mm->add_literal(1);
    auto two  = mm->add_literal(2);
    auto sum1 = mm->add_instruction(sum_op{}, one, two);
    auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
    mm->add_instruction(pass_op{}, sum2);
    // No sum has a pass instruction as an input, so either_arg fails
    // everywhere.
    auto m =
        match::name("sum")(match::either_arg(0, 1)(match::name("pass"), match::name("@literal")));
    auto r = find_match(*mm, m);
    // Compare against end() of the searched module, which is what find_match
    // returns when no instruction matches.
    EXPECT(bool{r.result == mm->end()});
}

Paul's avatar
Paul committed
386
387
388
TEST_CASE(match_either_args_any1)
{
    // Two unconstrained bound matchers: the first sum is expected to match,
    // with "x" and "y" bound to different inputs.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1  = main_mod->add_literal(1);
    auto lit2  = main_mod->add_literal(2);
    auto inner = main_mod->add_instruction(sum_op{}, lit1, lit2);
    auto outer = main_mod->add_instruction(sum_op{}, inner, lit2);
    main_mod->add_instruction(pass_op{}, outer);

    auto either  = match::either_arg(0, 1)(match::any().bind("x"), match::any().bind("y"));
    auto pattern = match::name("sum")(either);
    auto found   = find_match(*main_mod, pattern);

    EXPECT(bool{found.result == inner});
    EXPECT(bool{found.instructions.at("x") != found.instructions.at("y")});
}

TEST_CASE(match_either_args_any2)
{
    // "x" is unconstrained and "y" must be a literal; the first sum matches
    // with distinct bindings.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1  = main_mod->add_literal(1);
    auto lit2  = main_mod->add_literal(2);
    auto inner = main_mod->add_instruction(sum_op{}, lit1, lit2);
    auto outer = main_mod->add_instruction(sum_op{}, inner, lit2);
    main_mod->add_instruction(pass_op{}, outer);

    auto either =
        match::either_arg(0, 1)(match::any().bind("x"), match::name("@literal").bind("y"));
    auto pattern = match::name("sum")(either);
    auto found   = find_match(*main_mod, pattern);

    EXPECT(bool{found.result == inner});
    EXPECT(bool{found.instructions.at("x") != found.instructions.at("y")});
}

TEST_CASE(match_either_args_any3)
{
    // "x" must be a literal and "y" is unconstrained; the first sum matches
    // with distinct bindings.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1  = main_mod->add_literal(1);
    auto lit2  = main_mod->add_literal(2);
    auto inner = main_mod->add_instruction(sum_op{}, lit1, lit2);
    auto outer = main_mod->add_instruction(sum_op{}, inner, lit2);
    main_mod->add_instruction(pass_op{}, outer);

    auto either =
        match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y"));
    auto pattern = match::name("sum")(either);
    auto found   = find_match(*main_mod, pattern);

    EXPECT(bool{found.result == inner});
    EXPECT(bool{found.instructions.at("x") != found.instructions.at("y")});
}

TEST_CASE(match_either_args_any4)
{
    // "x" must be a sum, so only the outer sum (whose first input is the
    // inner sum) qualifies.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1  = main_mod->add_literal(1);
    auto lit2  = main_mod->add_literal(2);
    auto inner = main_mod->add_instruction(sum_op{}, lit1, lit2);
    auto outer = main_mod->add_instruction(sum_op{}, inner, lit2);
    main_mod->add_instruction(pass_op{}, outer);

    auto either  = match::either_arg(0, 1)(match::name("sum").bind("x"), match::any().bind("y"));
    auto pattern = match::name("sum")(either);
    auto found   = find_match(*main_mod, pattern);

    EXPECT(bool{found.result == outer});
    EXPECT(bool{found.instructions.at("x") != found.instructions.at("y")});
}

TEST_CASE(match_either_args_any5)
{
    // Mirror of match_either_args_any4: "y" must be a sum, so the outer sum
    // is the expected match.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1  = main_mod->add_literal(1);
    auto lit2  = main_mod->add_literal(2);
    auto inner = main_mod->add_instruction(sum_op{}, lit1, lit2);
    auto outer = main_mod->add_instruction(sum_op{}, inner, lit2);
    main_mod->add_instruction(pass_op{}, outer);

    auto either  = match::either_arg(0, 1)(match::any().bind("x"), match::name("sum").bind("y"));
    auto pattern = match::name("sum")(either);
    auto found   = find_match(*main_mod, pattern);

    EXPECT(bool{found.result == outer});
    EXPECT(bool{found.instructions.at("x") != found.instructions.at("y")});
}

Paul's avatar
Paul committed
471
TEST_CASE(match_all_of1)
{
    // all_of over two arg() checks that both hold for the sum.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1 = main_mod->add_literal(1);
    auto lit2 = main_mod->add_literal(2);
    auto add  = main_mod->add_instruction(sum_op{}, lit1, lit2);
    main_mod->add_instruction(pass_op{}, add);

    auto both_literal = match::all_of(match::arg(0)(match::name("@literal")),
                                      match::arg(1)(match::name("@literal")));
    auto pattern      = match::name("sum")(both_literal);
    auto found        = find_match(*main_mod, pattern);
    EXPECT(bool{found.result == add});
}

Paul's avatar
Paul committed
486
TEST_CASE(match_all_of2)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto one = mm->add_literal(1);
    auto two = mm->add_literal(2);
    auto sum = mm->add_instruction(sum_op{}, one, two);
    mm->add_instruction(pass_op{}, sum);
    // arg(0) is a literal rather than a sum, so all_of fails.
    auto m = match::name("sum")(
        match::all_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("@literal"))));
    auto r = find_match(*mm, m);
    // Compare against end() of the searched module, which is what find_match
    // returns when no instruction matches.
    EXPECT(bool{r.result == mm->end()});
}

Paul's avatar
Paul committed
501
502
503
TEST_CASE(match_all_of3)
{
    // Nested all_of(all_of(...)) should behave like a flat all_of.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1 = main_mod->add_literal(1);
    auto lit2 = main_mod->add_literal(2);
    auto add  = main_mod->add_instruction(sum_op{}, lit1, lit2);
    main_mod->add_instruction(pass_op{}, add);

    auto inner_all = match::all_of(match::arg(0)(match::name("@literal")),
                                   match::arg(1)(match::name("@literal")));
    auto pattern   = match::name("sum")(match::all_of(inner_all));
    auto found     = find_match(*main_mod, pattern);
    EXPECT(bool{found.result == add});
}

Paul's avatar
Paul committed
516
517
518
TEST_CASE(match_lazy_any_of)
{
    // any_of must stop at any(), which succeeds, without evaluating throws().
    // The first instruction (the literal) is therefore the match.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit = main_mod->add_literal(1);
    main_mod->add_instruction(pass_op{}, lit);

    auto pattern = match::any_of(match::any(), throws());
    auto found   = find_match(*main_mod, pattern);
    EXPECT(bool{found.result == lit});
}

TEST_CASE(match_lazy_all_of)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto one = mm->add_literal(1);
    mm->add_instruction(pass_op{}, one);
    // none() fails first, so all_of must give up without evaluating throws().
    auto m = match::all_of(match::none(), throws());
    auto r = find_match(*mm, m);
    // Compare against end() of the searched module, which is what find_match
    // returns when no instruction matches.
    EXPECT(bool{r.result == mm->end()});
}

TEST_CASE(match_lazy_none_of)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto one = mm->add_literal(1);
    mm->add_instruction(pass_op{}, one);
    // any() succeeds first, so none_of must fail without evaluating throws().
    auto m = match::none_of(match::any(), throws());
    auto r = find_match(*mm, m);
    // Compare against end() of the searched module, which is what find_match
    // returns when no instruction matches.
    EXPECT(bool{r.result == mm->end()});
}

Paul's avatar
Paul committed
552
TEST_CASE(match_any_of1)
{
    // arg(0) is not a sum, but arg(1) is a literal, so any_of succeeds.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1 = main_mod->add_literal(1);
    auto lit2 = main_mod->add_literal(2);
    auto add  = main_mod->add_instruction(sum_op{}, lit1, lit2);
    main_mod->add_instruction(pass_op{}, add);

    auto either_holds = match::any_of(match::arg(0)(match::name("sum")),
                                      match::arg(1)(match::name("@literal")));
    auto pattern      = match::name("sum")(either_holds);
    auto found        = find_match(*main_mod, pattern);
    EXPECT(bool{found.result == add});
}

Paul's avatar
Paul committed
567
TEST_CASE(match_any_of2)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto one = mm->add_literal(1);
    auto two = mm->add_literal(2);
    auto sum = mm->add_instruction(sum_op{}, one, two);
    mm->add_instruction(pass_op{}, sum);
    // Neither input of the sum is a sum, so any_of fails.
    auto m = match::name("sum")(
        match::any_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("sum"))));
    auto r = find_match(*mm, m);
    // Compare against end() of the searched module, which is what find_match
    // returns when no instruction matches.
    EXPECT(bool{r.result == mm->end()});
}

Paul's avatar
Paul committed
582
583
584
TEST_CASE(match_any_of_lazy1)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto one = mm->add_literal(1);
    auto two = mm->add_literal(2);
    auto sum = mm->add_instruction(sum_op{}, one, two);
    mm->add_instruction(pass_op{}, sum);
    // The first alternative already succeeds, so only its binding ("x") is
    // expected in the result; "y" must not appear.
    auto m = match::name("sum")(
        match::any_of(match::args(match::any(), match::any()).bind("x"),
                      match::args(match::name("sum"), match::name("sum")).bind("y")));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == sum});
    EXPECT(migraphx::contains(r.instructions, "x"));
    // at() never inserts; operator[] would add a value-initialized entry if
    // the key were missing and then compare that singular iterator.
    EXPECT(bool{r.instructions.at("x") == sum});
    EXPECT(not migraphx::contains(r.instructions, "y"));
}

TEST_CASE(match_any_of_lazy2)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto one = mm->add_literal(1);
    auto two = mm->add_literal(2);
    auto sum = mm->add_instruction(sum_op{}, one, two);
    mm->add_instruction(pass_op{}, sum);
    // Both inputs are literals, so the first alternative succeeds and only
    // "x" is expected to be bound.
    auto m = match::name("sum")(
        match::any_of(match::args(match::name("@literal"), match::name("@literal")).bind("x"),
                      match::args(match::any(), match::any()).bind("y")));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == sum});
    EXPECT(migraphx::contains(r.instructions, "x"));
    // at() never inserts; operator[] would add a value-initialized entry if
    // the key were missing and then compare that singular iterator.
    EXPECT(bool{r.instructions.at("x") == sum});
    EXPECT(not migraphx::contains(r.instructions, "y"));
}

TEST_CASE(match_any_of_lazy3)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto one = mm->add_literal(1);
    auto two = mm->add_literal(2);
    auto sum = mm->add_instruction(sum_op{}, one, two);
    mm->add_instruction(pass_op{}, sum);
    // Both alternatives would succeed; the test expects only the first one's
    // binding ("x") to be recorded.
    auto m = match::name("sum")(
        match::any_of(match::args(match::any(), match::any()).bind("x"),
                      match::args(match::name("@literal"), match::name("@literal")).bind("y")));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == sum});
    EXPECT(migraphx::contains(r.instructions, "x"));
    // at() never inserts; operator[] would add a value-initialized entry if
    // the key were missing and then compare that singular iterator.
    EXPECT(bool{r.instructions.at("x") == sum});
    EXPECT(not migraphx::contains(r.instructions, "y"));
}

TEST_CASE(match_any_of_lazy4)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto one = mm->add_literal(1);
    auto two = mm->add_literal(2);
    auto sum = mm->add_instruction(sum_op{}, one, two);
    mm->add_instruction(pass_op{}, sum);
    // The first alternative binds the sum's inputs as "x1"/"y1"; the second
    // alternative's bindings ("x2"/"y2") must not leak into the result.
    auto m = match::name("sum")(match::any_of(
        match::args(match::name("@literal").bind("x1"), match::name("@literal").bind("y1")),
        match::args(match::any().bind("x2"), match::any().bind("y2"))));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == sum});
    EXPECT(migraphx::contains(r.instructions, "x1"));
    EXPECT(migraphx::contains(r.instructions, "y1"));
    // at() never inserts; operator[] would add a value-initialized entry if
    // the key were missing and then compare that singular iterator.
    EXPECT(bool{r.instructions.at("x1") == one});
    EXPECT(bool{r.instructions.at("y1") == two});
    EXPECT(not migraphx::contains(r.instructions, "x2"));
    EXPECT(not migraphx::contains(r.instructions, "y2"));
}

TEST_CASE(match_any_of_lazy5)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto one = mm->add_literal(1);
    auto two = mm->add_literal(2);
    auto sum = mm->add_instruction(sum_op{}, one, two);
    mm->add_instruction(pass_op{}, sum);
    // The unconstrained first alternative succeeds, so only "x1"/"y1" are
    // expected; "x2"/"y2" from the second alternative must be absent.
    auto m = match::name("sum")(match::any_of(
        match::args(match::any().bind("x1"), match::any().bind("y1")),
        match::args(match::name("@literal").bind("x2"), match::name("@literal").bind("y2"))));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == sum});
    EXPECT(migraphx::contains(r.instructions, "x1"));
    EXPECT(migraphx::contains(r.instructions, "y1"));
    // at() never inserts; operator[] would add a value-initialized entry if
    // the key were missing and then compare that singular iterator.
    EXPECT(bool{r.instructions.at("x1") == one});
    EXPECT(bool{r.instructions.at("y1") == two});
    EXPECT(not migraphx::contains(r.instructions, "x2"));
    EXPECT(not migraphx::contains(r.instructions, "y2"));
}

Paul's avatar
Paul committed
683
TEST_CASE(match_none_of1)
{
    // Neither input of the sum is a sum, so none_of holds and the sum matches.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1 = main_mod->add_literal(1);
    auto lit2 = main_mod->add_literal(2);
    auto add  = main_mod->add_instruction(sum_op{}, lit1, lit2);
    main_mod->add_instruction(pass_op{}, add);

    auto no_sum_inputs = match::none_of(match::arg(0)(match::name("sum")),
                                        match::arg(1)(match::name("sum")));
    auto pattern       = match::name("sum")(no_sum_inputs);
    auto found         = find_match(*main_mod, pattern);
    EXPECT(bool{found.result == add});
}

Paul's avatar
Paul committed
698
TEST_CASE(match_none_of2)
{
    migraphx::program p;

    auto* mm = p.get_main_module();
    auto one = mm->add_literal(1);
    auto two = mm->add_literal(2);
    auto sum = mm->add_instruction(sum_op{}, one, two);
    mm->add_instruction(pass_op{}, sum);
    // Both inputs are literals, so none_of fails for the sum.
    auto m = match::name("sum")(match::none_of(match::arg(0)(match::name("@literal")),
                                               match::arg(1)(match::name("@literal"))));
    auto r = find_match(*mm, m);
    // Compare against end() of the searched module, which is what find_match
    // returns when no instruction matches.
    EXPECT(bool{r.result == mm->end()});
}

Paul's avatar
Paul committed
713
714
715
TEST_CASE(match_output1)
{
    // Graph: diff = minus(2, 1); total = sum(diff, 2); pass(total).
    // The minus feeds the sum, so output(name("sum")) should accept it.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    auto lit1  = main_mod->add_literal(1);
    auto lit2  = main_mod->add_literal(2);
    auto diff  = main_mod->add_instruction(minus_op{}, lit2, lit1);
    auto total = main_mod->add_instruction(sum_op{}, diff, lit2);
    main_mod->add_instruction(pass_op{}, total);

    auto pattern = match::name("minus")(match::output(match::name("sum")));
    auto found   = find_match(*main_mod, pattern);
    EXPECT(bool{found.result == diff});
}

TEST_CASE(match_output2)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto minus = mm->add_instruction(minus_op{}, two, one);
    auto sum   = mm->add_instruction(sum_op{}, minus, two);
    mm->add_instruction(pass_op{}, sum);
    // The test expects neither literal to satisfy output(name("sum")) in
    // this graph.
    auto m = match::name("@literal")(match::output(match::name("sum")));
    auto r = find_match(*mm, m);
    // Compare against end() of the searched module, which is what find_match
    // returns when no instruction matches.
    EXPECT(bool{r.result == mm->end()});
}

TEST_CASE(match_skip_output1)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto minus = mm->add_instruction(minus_op{}, two, one);
    auto sum   = mm->add_instruction(sum_op{}, minus, two);
    mm->add_instruction(pass_op{}, sum);
    // `minus` feeds `sum` directly: there is no `pass` between them to skip.
    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == minus});
}

TEST_CASE(match_skip_output2)
{
    migraphx::program p;

    auto* mm        = p.get_main_module();
    auto one        = mm->add_literal(1);
    auto two        = mm->add_literal(2);
    auto minus      = mm->add_instruction(minus_op{}, two, one);
    auto minus_pass = mm->add_instruction(pass_op{}, minus);
    auto sum        = mm->add_instruction(sum_op{}, minus_pass, two);
    mm->add_instruction(pass_op{}, sum);
    // A single `pass` sits between `minus` and `sum`; skip_output(name("pass"))
    // is expected to look through it.
    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == minus});
}

TEST_CASE(match_skip_output3)
{
    migraphx::program p;

    auto* mm         = p.get_main_module();
    auto one         = mm->add_literal(1);
    auto two         = mm->add_literal(2);
    auto minus       = mm->add_instruction(minus_op{}, two, one);
    auto minus_pass1 = mm->add_instruction(pass_op{}, minus);
    auto minus_pass2 = mm->add_instruction(pass_op{}, minus_pass1);
    auto minus_pass3 = mm->add_instruction(pass_op{}, minus_pass2);
    auto sum         = mm->add_instruction(sum_op{}, minus_pass3, two);
    mm->add_instruction(pass_op{}, sum);
    // Three chained `pass` instructions sit between `minus` and `sum`; all of them
    // are expected to be skipped.
    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == minus});
}

TEST_CASE(match_skip_output4)
{
    migraphx::program p;

    auto* mm  = p.get_main_module();
    auto one  = mm->add_literal(1);
    auto two  = mm->add_literal(2);
    auto pass = mm->add_instruction(pass_op{}, one);
    auto sum  = mm->add_instruction(sum_op{}, pass, two);
    mm->add_instruction(pass_op{}, sum);
    // `one` reaches `sum` through a `pass`, and `two` feeds `sum` directly; the test
    // expects find_match to report `two`.
    auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == two});
}

TEST_CASE(match_skip_output5)
{
    migraphx::program p;

    auto* mm  = p.get_main_module();
    auto one  = mm->add_literal(1);
    auto two  = mm->add_literal(2);
    auto pass = mm->add_instruction(pass_op{}, one);
    auto sum1 = mm->add_instruction(sum_op{}, pass, two);
    auto sum2 = mm->add_instruction(sum_op{}, sum1, one);
    auto sum3 = mm->add_instruction(sum_op{}, sum2, two);
    mm->add_instruction(pass_op{}, sum3);
    // `one` feeds both `pass` and `sum2`, and `two` feeds both `sum1` and `sum3`;
    // the test expects no literal to match.
    auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
    auto r = find_match(*mm, m);
    // find_match treats the searched module's end() as "no match".
    EXPECT(bool{r.result == mm->end()});
}

TEST_CASE(match_skip_output6)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto minus = mm->add_instruction(minus_op{}, two, one);
    auto sum1  = mm->add_instruction(sum_op{}, minus, two);
    auto sum2  = mm->add_instruction(sum_op{}, sum1, one);
    auto sum3  = mm->add_instruction(sum_op{}, sum2, two);
    mm->add_instruction(pass_op{}, sum3);
    // `minus` is consumed only by `sum1`, so it matches even though further `sum`
    // instructions follow downstream.
    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == minus});
}

TEST_CASE(match_skip_output7)
{
    migraphx::program p;

    auto* mm    = p.get_main_module();
    auto one    = mm->add_literal(1);
    auto two    = mm->add_literal(2);
    auto minus1 = mm->add_instruction(minus_op{}, two, one);
    auto minus2 = mm->add_instruction(minus_op{}, two, minus1);
    auto sum    = mm->add_instruction(sum_op{}, one, minus2);
    mm->add_instruction(pass_op{}, sum);
    // `minus1` feeds `minus2`, so it is the `minus` whose output is another `minus`.
    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("minus")));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == minus1});
}

Paul's avatar
Paul committed
857
TEST_CASE(match_bind1)
Paul's avatar
Paul committed
858
{
Paul's avatar
Paul committed
859
    migraphx::program p;
860
861
862
863
864
865

    auto* mm  = p.get_main_module();
    auto one  = mm->add_literal(1);
    auto two  = mm->add_literal(2);
    auto sum  = mm->add_instruction(sum_op{}, one, two);
    auto pass = mm->add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
866
    auto m    = match::name("pass")(
Paul's avatar
Paul committed
867
868
869
                 match::args(match::name("sum")(match::args(match::name("@literal").bind("one"),
                                                            match::name("@literal").bind("two")))
                                 .bind("sum")),
Paul's avatar
Paul committed
870
                 match::standard_shape())
Paul's avatar
Paul committed
871
                 .bind("pass");
872
    auto r = find_match(*mm, m);
Paul's avatar
Paul committed
873
874
875
876
877
878
879
    EXPECT(bool{r.instructions.at("one") == one});
    EXPECT(bool{r.instructions.at("two") == two});
    EXPECT(bool{r.instructions.at("sum") == sum});
    EXPECT(bool{r.instructions.at("pass") == pass});
    EXPECT(bool{r.result == pass});
}

Paul Fultz II's avatar
Paul Fultz II committed
880
881
882
TEST_CASE(match_has_value1)
{
    migraphx::program p;
883
884
885
886
887
888
889

    auto* mm  = p.get_main_module();
    auto one  = mm->add_literal(1);
    auto two  = mm->add_literal(2);
    auto sum1 = mm->add_instruction(sum_op{}, one, two);
    auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
    mm->add_instruction(pass_op{}, sum2);
Paul Fultz II's avatar
Paul Fultz II committed
890
    auto m = match::has_value(1);
891
    auto r = find_match(*mm, m);
Paul Fultz II's avatar
Paul Fultz II committed
892
893
894
895
896
897
    EXPECT(bool{r.result == one});
}

TEST_CASE(match_has_value2)
{
    migraphx::program p;

    auto* mm  = p.get_main_module();
    auto one  = mm->add_literal(1);
    auto two  = mm->add_literal(2);
    auto sum1 = mm->add_instruction(sum_op{}, one, two);
    auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
    mm->add_instruction(pass_op{}, sum2);
    // Only the literal holding 2 should satisfy has_value(2).
    auto m = match::has_value(2);
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == two});
}

TEST_CASE(match_has_value3)
{
    migraphx::program p;

    auto* mm  = p.get_main_module();
    auto one  = mm->add_literal(1);
    auto two  = mm->add_literal(2);
    auto sum1 = mm->add_instruction(sum_op{}, one, two);
    auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
    mm->add_instruction(pass_op{}, sum2);
    // `sum1` is the only `sum` whose arguments are, in order, the literals 1 and 2.
    auto m = match::name("sum")(match::args(match::has_value(1), match::has_value(2)));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == sum1});
}

TEST_CASE(match_has_value4)
{
    migraphx::program p;

    auto* mm  = p.get_main_module();
    auto one  = mm->add_literal(1);
    auto two  = mm->add_literal(2);
    auto sum1 = mm->add_instruction(sum_op{}, one, two);
    auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
    mm->add_instruction(pass_op{}, sum2);
    // No literal holds 3, so nothing should match.
    auto m = match::has_value(3);
    auto r = find_match(*mm, m);
    // find_match treats the searched module's end() as "no match".
    EXPECT(bool{r.result == mm->end()});
}

TEST_CASE(match_has_value5)
{
    migraphx::program p;

    auto* mm  = p.get_main_module();
    auto one  = mm->add_literal(1);
    auto two  = mm->add_literal(2);
    auto sum1 = mm->add_instruction(sum_op{}, one, two);
    auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
    mm->add_instruction(pass_op{}, sum2);
    // No `sum` takes the literals 1 and 3 as its arguments.
    auto m = match::name("sum")(match::args(match::has_value(1), match::has_value(3)));
    auto r = find_match(*mm, m);
    // find_match treats the searched module's end() as "no match".
    EXPECT(bool{r.result == mm->end()});
}

TEST_CASE(match_has_value6)
{
    migraphx::program p;

    auto* mm  = p.get_main_module();
    auto one  = mm->add_literal(1);
    auto two  = mm->add_literal(2);
    auto sum1 = mm->add_instruction(sum_op{}, one, two);
    auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
    mm->add_instruction(pass_op{}, sum2);
    // `sum1` takes (1, 2), not (2, 1): args() is expected to be order-sensitive.
    auto m = match::name("sum")(match::args(match::has_value(2), match::has_value(1)));
    auto r = find_match(*mm, m);
    // find_match treats the searched module's end() as "no match".
    EXPECT(bool{r.result == mm->end()});
}

TEST_CASE(match_tree1)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto three = mm->add_literal(3);
    auto sum1  = mm->add_instruction(sum_op{}, one, two);
    auto sum2  = mm->add_instruction(sum_op{}, sum1, three);
    mm->add_instruction(pass_op{}, sum2);
    // sum2 = sum(sum(1, 2), 3): its leaves, read left to right, are 1, 2, 3.
    auto m = match::tree("sum", match::has_value(1), match::has_value(2), match::has_value(3));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == sum2});
}

TEST_CASE(match_tree2)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto three = mm->add_literal(3);
    auto sum1  = mm->add_instruction(sum_op{}, one, two);
    auto sum2  = mm->add_instruction(sum_op{}, sum1, three);
    mm->add_instruction(pass_op{}, sum2);
    // Leaves are 1, 2, 3 but the pattern asks for 2, 1, 3; tree() is order-sensitive.
    auto m = match::tree("sum", match::has_value(2), match::has_value(1), match::has_value(3));
    auto r = find_match(*mm, m);
    // find_match treats the searched module's end() as "no match".
    EXPECT(bool{r.result == mm->end()});
}

TEST_CASE(match_tree3)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto three = mm->add_literal(3);
    auto sum1  = mm->add_instruction(sum_op{}, one, two);
    auto sum2  = mm->add_instruction(sum_op{}, three, sum1);
    mm->add_instruction(pass_op{}, sum2);
    // sum2 = sum(3, sum(1, 2)): its leaves, read left to right, are 3, 1, 2.
    auto m = match::tree("sum", match::has_value(3), match::has_value(1), match::has_value(2));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == sum2});
}

TEST_CASE(match_tree4)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto three = mm->add_literal(3);
    auto sum1  = mm->add_instruction(sum_op{}, one, two);
    auto sum2  = mm->add_instruction(sum_op{}, sum1, three);
    mm->add_instruction(pass_op{}, sum2);
    // The pattern has four leaves but the `sum` tree only has three.
    auto m = match::tree(
        "sum", match::has_value(1), match::has_value(2), match::has_value(3), match::has_value(4));
    auto r = find_match(*mm, m);
    // find_match treats the searched module's end() as "no match".
    EXPECT(bool{r.result == mm->end()});
}

TEST_CASE(match_tree5)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto three = mm->add_literal(3);
    auto sum1  = mm->add_instruction(sum_op{}, one, two);
    auto sum2  = mm->add_instruction(sum_op{}, sum1, three);
    mm->add_instruction(pass_op{}, sum2);
    // No `sum` in the module has exactly the leaves (2, 3).
    auto m = match::tree("sum", match::has_value(2), match::has_value(3));
    auto r = find_match(*mm, m);
    // find_match treats the searched module's end() as "no match".
    EXPECT(bool{r.result == mm->end()});
}

TEST_CASE(match_tree6)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto three = mm->add_literal(3);
    auto sum1  = mm->add_instruction(sum_op{}, one, two);
    auto sum2  = mm->add_instruction(sum_op{}, sum1, three);
    mm->add_instruction(pass_op{}, sum2);
    // No `sum` in the module has exactly the leaves (1, 3).
    auto m = match::tree("sum", match::has_value(1), match::has_value(3));
    auto r = find_match(*mm, m);
    // find_match treats the searched module's end() as "no match".
    EXPECT(bool{r.result == mm->end()});
}

TEST_CASE(match_unordered_tree1)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto three = mm->add_literal(3);
    auto sum1  = mm->add_instruction(sum_op{}, one, two);
    auto sum2  = mm->add_instruction(sum_op{}, sum1, three);
    mm->add_instruction(pass_op{}, sum2);
    // Leaves are 1, 2, 3; the pattern lists them as 3, 2, 1 and is still expected to
    // match because unordered_tree ignores argument order.
    auto m =
        match::unordered_tree("sum", match::has_value(3), match::has_value(2), match::has_value(1));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == sum2});
}

TEST_CASE(match_unordered_tree2)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto three = mm->add_literal(3);
    auto sum1  = mm->add_instruction(sum_op{}, one, two);
    auto sum2  = mm->add_instruction(sum_op{}, three, sum1);
    mm->add_instruction(pass_op{}, sum2);
    // sum2 = sum(3, sum(1, 2)); the unordered pattern (3, 2, 1) is expected to match.
    auto m =
        match::unordered_tree("sum", match::has_value(3), match::has_value(2), match::has_value(1));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == sum2});
}

TEST_CASE(match_unordered_tree3)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto three = mm->add_literal(3);
    auto sum1  = mm->add_instruction(sum_op{}, two, one);
    auto sum2  = mm->add_instruction(sum_op{}, sum1, three);
    mm->add_instruction(pass_op{}, sum2);
    // sum2 = sum(sum(2, 1), 3); the unordered pattern (3, 2, 1) is expected to match.
    auto m =
        match::unordered_tree("sum", match::has_value(3), match::has_value(2), match::has_value(1));
    auto r = find_match(*mm, m);
    EXPECT(bool{r.result == sum2});
}

TEST_CASE(match_unordered_tree4)
{
    migraphx::program p;

    auto* mm   = p.get_main_module();
    auto one   = mm->add_literal(1);
    auto two   = mm->add_literal(2);
    auto three = mm->add_literal(3);
    auto sum1  = mm->add_instruction(sum_op{}, one, two);
    auto sum2  = mm->add_instruction(sum_op{}, sum1, three);
    mm->add_instruction(pass_op{}, sum2);
    // No literal holds 4, so even an unordered match must fail.
    auto m =
        match::unordered_tree("sum", match::has_value(4), match::has_value(2), match::has_value(1));
    auto r = find_match(*mm, m);
    // find_match treats the searched module's end() as "no match".
    EXPECT(bool{r.result == mm->end()});
}

Paul's avatar
Paul committed
1135
1136
struct match_find_sum
{
Paul's avatar
Paul committed
1137
    migraphx::instruction_ref ins;
Paul's avatar
Paul committed
1138
    auto matcher() const { return match::name("sum"); }
Paul's avatar
Paul committed
1139

1140
    void apply(migraphx::module&, const match::matcher_result& r) const
Paul's avatar
Paul committed
1141
1142
1143
    {
        EXPECT(bool{r.result == ins});
    }
Paul's avatar
Paul committed
1144
1145
1146
1147
};

struct match_find_literal
{
Paul's avatar
Paul committed
1148
    migraphx::instruction_ref ins;
Paul's avatar
Paul committed
1149
    auto matcher() const { return match::name("@literal"); }
Paul's avatar
Paul committed
1150

1151
    void apply(migraphx::module&, const match::matcher_result& r) const
Paul's avatar
Paul committed
1152
1153
1154
1155
1156
1157
    {
        EXPECT(bool{r.result != ins});
        EXPECT(r.result->name() == "@literal");
    }
};

Paul's avatar
Paul committed
1158
TEST_CASE(match_finder)
Paul's avatar
Paul committed
1159
{
Paul's avatar
Paul committed
1160
    migraphx::program p;
1161
1162
1163
1164
1165
1166
1167

    auto* mm = p.get_main_module();
    auto one = mm->add_literal(1);
    auto two = mm->add_literal(2);
    auto sum = mm->add_instruction(sum_op{}, one, two);
    mm->add_instruction(pass_op{}, sum);
    match::find_matches(*mm, match_find_sum{sum}, match_find_literal{sum});
Paul's avatar
Paul committed
1168
1169
}

Paul's avatar
Paul committed
1170
int main(int argc, const char* argv[]) { test::run(argc, argv); }