// matcher.cpp: unit tests for the migraphx matcher (migraphx/matcher.hpp).
#include <migraphx/matcher.hpp>
#include <migraphx/iterator_for.hpp>
#include <test.hpp>
#include <basic_ops.hpp>

Paul's avatar
Paul committed
6
namespace match = migraphx::match;
Paul's avatar
Paul committed
7

Paul's avatar
Paul committed
8
9
MIGRAPHX_PRED_MATCHER(throws, migraphx::instruction_ref) { MIGRAPHX_THROW("Matcher throws"); }

Paul's avatar
Paul committed
10
11
void match1()
{
12
13
14
15
    migraphx::module mm;
    auto l = mm.add_literal(1);
    auto m = match::standard_shape();
    auto r = find_match(mm, m);
Paul's avatar
Paul committed
16
17
18
    EXPECT(bool{r.result == l});
}

Paul's avatar
Paul committed
19
TEST_CASE(match_name1)
Paul's avatar
Paul committed
20
{
21
22
23
24
25
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
26
    auto m = match::name("sum");
27
    auto r = find_match(mm, m);
Paul's avatar
Paul committed
28
29
30
    EXPECT(bool{r.result == sum});
}

Paul's avatar
Paul committed
31
TEST_CASE(match_name2)
Paul's avatar
Paul committed
32
{
33
34
35
36
37
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
38
    auto m = match::name("min");
39
40
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
Paul's avatar
Paul committed
41
42
}

Paul's avatar
Paul committed
43
TEST_CASE(match_name3)
Paul's avatar
Paul committed
44
{
45
46
47
48
49
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
50
    auto m = match::name("sum")(match::standard_shape());
51
    auto r = find_match(mm, m);
Paul's avatar
Paul committed
52
53
54
    EXPECT(bool{r.result == sum});
}

Paul's avatar
Paul committed
55
TEST_CASE(match_arg1)
Paul's avatar
Paul committed
56
{
57
58
59
60
61
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
62
    auto m = match::name("sum")(match::arg(0)(match::name("@literal")), match::standard_shape());
63
    auto r = find_match(mm, m);
Paul's avatar
Paul committed
64
65
66
    EXPECT(bool{r.result == sum});
}

Paul's avatar
Paul committed
67
TEST_CASE(match_arg2)
Paul's avatar
Paul committed
68
{
69
70
71
72
73
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
74
    auto m = match::name("sum")(match::arg(0)(match::name("sum")), match::standard_shape());
75
76
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
Paul's avatar
Paul committed
77
78
}

Paul's avatar
Paul committed
79
TEST_CASE(match_arg3)
Paul's avatar
Paul committed
80
{
81
82
83
84
85
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
86
    auto m = match::name("sum")(match::arg(1)(match::name("@literal")), match::standard_shape());
87
    auto r = find_match(mm, m);
Paul's avatar
Paul committed
88
89
90
    EXPECT(bool{r.result == sum});
}

Paul's avatar
Paul committed
91
TEST_CASE(match_arg4)
Paul's avatar
Paul committed
92
{
93
94
95
96
97
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum  = mm.add_instruction(sum_op{}, one, two);
    auto pass = mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
98
    auto m    = match::name("pass")(match::arg(0)(match::name("sum")), match::standard_shape());
99
    auto r    = find_match(mm, m);
Paul's avatar
Paul committed
100
101
102
    EXPECT(bool{r.result == pass});
}

Paul's avatar
Paul committed
103
TEST_CASE(match_arg5)
Paul's avatar
Paul committed
104
{
105
106
107
108
109
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
110
    auto m = match::name("pass")(match::arg(1)(match::name("sum")), match::standard_shape());
111
112
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
Paul's avatar
Paul committed
113
114
}

Paul's avatar
Paul committed
115
TEST_CASE(match_arg6)
Paul's avatar
Paul committed
116
{
117
118
119
120
121
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
122
    auto m = match::name("sum")(match::arg(0)(match::name("@literal")));
123
    auto r = find_match(mm, m);
Paul's avatar
Paul committed
124
125
126
    EXPECT(bool{r.result == sum});
}

Paul's avatar
Paul committed
127
TEST_CASE(match_arg7)
Paul's avatar
Paul committed
128
{
129
130
131
132
133
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
134
    auto m = match::name("sum")(match::arg(0)(match::name("@literal")),
Paul's avatar
Paul committed
135
                                match::arg(1)(match::name("@literal")));
136
    auto r = find_match(mm, m);
Paul's avatar
Paul committed
137
138
139
    EXPECT(bool{r.result == sum});
}

Paul's avatar
Paul committed
140
141
TEST_CASE(match_arg8)
{
142
143
144
145
146
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
147
    auto m = match::name("sum")(match::all_of(match::arg(0)(match::name("@literal")),
Paul's avatar
Paul committed
148
                                              match::arg(1)(match::name("@literal"))),
Paul's avatar
Paul committed
149
                                match::standard_shape());
150
    auto r = find_match(mm, m);
Paul's avatar
Paul committed
151
152
153
    EXPECT(bool{r.result == sum});
}

Paul's avatar
Paul committed
154
155
TEST_CASE(match_nargs1)
{
156
157
158
159
160
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
161
    auto m = match::name("sum")(match::nargs(2));
162
    auto r = find_match(mm, m);
Paul's avatar
Paul committed
163
164
165
166
167
    EXPECT(bool{r.result == sum});
}

// nargs(2) combined with standard_shape().
TEST_CASE(match_nargs2)
{
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
    auto m = match::name("sum")(match::nargs(2), match::standard_shape());
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum});
}

// nargs(2) wrapped in a single-element all_of().
TEST_CASE(match_nargs3)
{
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
    auto m = match::name("sum")(match::all_of(match::nargs(2)));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum});
}

Paul's avatar
Paul committed
190
TEST_CASE(match_args1)
Paul's avatar
Paul committed
191
{
192
193
194
195
196
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
197
198
    auto m = match::name("sum")(match::args(match::name("@literal"), match::name("@literal")),
                                match::standard_shape());
199
    auto r = find_match(mm, m);
Paul's avatar
Paul committed
200
201
202
    EXPECT(bool{r.result == sum});
}

Paul's avatar
Paul committed
203
TEST_CASE(match_args2)
Paul's avatar
Paul committed
204
{
205
206
207
208
209
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
210
211
    auto m = match::name("sum")(match::args(match::name("@literal"), match::name("sum")),
                                match::standard_shape());
212
213
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
Paul's avatar
Paul committed
214
215
}

Paul's avatar
Paul committed
216
TEST_CASE(match_args3)
Paul's avatar
Paul committed
217
{
218
219
220
221
222
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
223
    auto m = match::name("sum")(match::args(match::name("@literal")), match::standard_shape());
224
225
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
Paul's avatar
Paul committed
226
227
}

Paul's avatar
Paul committed
228
TEST_CASE(match_args4)
Paul's avatar
Paul committed
229
{
230
231
232
233
234
235
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum1 = mm.add_instruction(sum_op{}, one, two);
    auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
    mm.add_instruction(pass_op{}, sum2);
Paul's avatar
Paul committed
236
237
    auto m = match::name("sum")(match::args(match::name("sum"), match::name("@literal")),
                                match::standard_shape());
238
    auto r = find_match(mm, m);
Paul's avatar
Paul committed
239
240
241
    EXPECT(bool{r.result == sum2});
}

Paul's avatar
Paul committed
242
TEST_CASE(match_args5)
Paul's avatar
Paul committed
243
{
244
245
246
247
248
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
249
250
    auto m = match::name("sum")(match::args(match::name("sum"), match::name("@literal")),
                                match::standard_shape());
251
252
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
Paul's avatar
Paul committed
253
254
}

Paul's avatar
Paul committed
255
TEST_CASE(match_args6)
Paul's avatar
Paul committed
256
{
257
258
259
260
261
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum  = mm.add_instruction(sum_op{}, one, two);
    auto pass = mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
262
    auto m    = match::name("pass")(match::args(match::name("sum")), match::standard_shape());
263
    auto r    = find_match(mm, m);
Paul's avatar
Paul committed
264
265
266
    EXPECT(bool{r.result == pass});
}

Paul's avatar
Paul committed
267
TEST_CASE(match_args7)
Paul's avatar
Paul committed
268
{
269
270
271
272
273
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum  = mm.add_instruction(sum_op{}, one, two);
    auto pass = mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
274
    auto m    = match::name("pass")(match::args(match::name("sum")(match::args(
Paul's avatar
Paul committed
275
276
                                     match::name("@literal"), match::name("@literal")))),
                                 match::standard_shape());
277
    auto r    = find_match(mm, m);
Paul's avatar
Paul committed
278
279
280
    EXPECT(bool{r.result == pass});
}

Paul's avatar
Paul committed
281
TEST_CASE(match_either_args1)
Paul's avatar
Paul committed
282
{
283
284
285
286
287
288
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum1 = mm.add_instruction(sum_op{}, one, two);
    auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
    mm.add_instruction(pass_op{}, sum2);
Paul's avatar
Paul committed
289
290
    auto m =
        match::name("sum")(match::either_arg(0, 1)(match::name("sum"), match::name("@literal")));
291
    auto r = find_match(mm, m);
Paul's avatar
Paul committed
292
293
294
    EXPECT(bool{r.result == sum2});
}

Paul's avatar
Paul committed
295
TEST_CASE(match_either_args2)
Paul's avatar
Paul committed
296
{
297
298
299
300
301
302
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum1 = mm.add_instruction(sum_op{}, one, two);
    auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
    mm.add_instruction(pass_op{}, sum2);
Paul's avatar
Paul committed
303
304
    auto m =
        match::name("sum")(match::either_arg(0, 1)(match::name("@literal"), match::name("sum")));
305
    auto r = find_match(mm, m);
Paul's avatar
Paul committed
306
307
308
    EXPECT(bool{r.result == sum2});
}

Paul's avatar
Paul committed
309
TEST_CASE(match_either_args3)
Paul's avatar
Paul committed
310
{
311
312
313
314
315
316
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum1 = mm.add_instruction(sum_op{}, one, two);
    auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
    mm.add_instruction(pass_op{}, sum2);
Paul's avatar
Paul committed
317
318
    auto m =
        match::name("sum")(match::either_arg(0, 1)(match::name("pass"), match::name("@literal")));
319
320
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
Paul's avatar
Paul committed
321
322
}

Paul's avatar
Paul committed
323
324
TEST_CASE(match_either_args_any1)
{
325
326
327
328
329
330
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum1 = mm.add_instruction(sum_op{}, one, two);
    auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
    mm.add_instruction(pass_op{}, sum2);
Paul's avatar
Paul committed
331
332
    auto m =
        match::name("sum")(match::either_arg(0, 1)(match::any().bind("x"), match::any().bind("y")));
333
    auto r = find_match(mm, m);
Paul's avatar
Paul committed
334
335
336
337
338
339
    EXPECT(bool{r.result == sum1});
    EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}

// (any, literal): the expected match is sum1, and "x"/"y" bind distinct inputs.
TEST_CASE(match_either_args_any2)
{
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum1 = mm.add_instruction(sum_op{}, one, two);
    auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::name("sum")(
        match::either_arg(0, 1)(match::any().bind("x"), match::name("@literal").bind("y")));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum1});
    EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}

// (literal, any): the expected match is sum1, and "x"/"y" bind distinct inputs.
TEST_CASE(match_either_args_any3)
{
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum1 = mm.add_instruction(sum_op{}, one, two);
    auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::name("sum")(
        match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y")));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum1});
    EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}

// (sum, any): only sum2 has a sum input, so it matches with distinct "x"/"y" bindings.
TEST_CASE(match_either_args_any4)
{
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum1 = mm.add_instruction(sum_op{}, one, two);
    auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::name("sum")(
        match::either_arg(0, 1)(match::name("sum").bind("x"), match::any().bind("y")));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum2});
    EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}

// (any, sum): only sum2 has a sum input, so it matches with distinct "x"/"y" bindings.
TEST_CASE(match_either_args_any5)
{
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum1 = mm.add_instruction(sum_op{}, one, two);
    auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::name("sum")(
        match::either_arg(0, 1)(match::any().bind("x"), match::name("sum").bind("y")));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum2});
    EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}

Paul's avatar
Paul committed
398
TEST_CASE(match_all_of1)
Paul's avatar
Paul committed
399
{
400
401
402
403
404
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
405
    auto m = match::name("sum")(match::all_of(match::arg(0)(match::name("@literal")),
Paul's avatar
Paul committed
406
                                              match::arg(1)(match::name("@literal"))));
407
    auto r = find_match(mm, m);
Paul's avatar
Paul committed
408
409
410
    EXPECT(bool{r.result == sum});
}

Paul's avatar
Paul committed
411
TEST_CASE(match_all_of2)
Paul's avatar
Paul committed
412
{
413
414
415
416
417
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
418
419
    auto m = match::name("sum")(
        match::all_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("@literal"))));
420
421
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
Paul's avatar
Paul committed
422
423
}

Paul's avatar
Paul committed
424
425
TEST_CASE(match_all_of3)
{
426
427
428
429
430
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
431
432
    auto m = match::name("sum")(match::all_of(match::all_of(
        match::arg(0)(match::name("@literal")), match::arg(1)(match::name("@literal")))));
433
    auto r = find_match(mm, m);
Paul's avatar
Paul committed
434
435
436
    EXPECT(bool{r.result == sum});
}

Paul's avatar
Paul committed
437
438
TEST_CASE(match_lazy_any_of)
{
439
440
441
    migraphx::module mm;
    auto one = mm.add_literal(1);
    mm.add_instruction(pass_op{}, one);
Paul's avatar
Paul committed
442
    auto m = match::any_of(match::any(), throws());
443
    auto r = find_match(mm, m);
Paul's avatar
Paul committed
444
445
446
447
448
    EXPECT(bool{r.result == one});
}

// all_of() must stop at the first failure; evaluating throws() would fail the test.
TEST_CASE(match_lazy_all_of)
{
    migraphx::module mm;
    auto one = mm.add_literal(1);
    mm.add_instruction(pass_op{}, one);
    auto m = match::all_of(match::none(), throws());
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
}

// none_of() must stop once a sub-matcher succeeds; evaluating throws() would fail the test.
TEST_CASE(match_lazy_none_of)
{
    migraphx::module mm;
    auto one = mm.add_literal(1);
    mm.add_instruction(pass_op{}, one);
    auto m = match::none_of(match::any(), throws());
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
}

Paul's avatar
Paul committed
467
TEST_CASE(match_any_of1)
Paul's avatar
Paul committed
468
{
469
470
471
472
473
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
474
475
    auto m = match::name("sum")(
        match::any_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("@literal"))));
476
    auto r = find_match(mm, m);
Paul's avatar
Paul committed
477
478
479
    EXPECT(bool{r.result == sum});
}

Paul's avatar
Paul committed
480
TEST_CASE(match_any_of2)
Paul's avatar
Paul committed
481
{
482
483
484
485
486
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
487
488
    auto m = match::name("sum")(
        match::any_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("sum"))));
489
490
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
Paul's avatar
Paul committed
491
492
}

Paul's avatar
Paul committed
493
494
TEST_CASE(match_any_of_lazy1)
{
495
496
497
498
499
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
500
    auto m = match::name("sum")(
Paul's avatar
Paul committed
501
502
        match::any_of(match::args(match::any(), match::any()).bind("x"),
                      match::args(match::name("sum"), match::name("sum")).bind("y")));
503
    auto r = find_match(mm, m);
Paul's avatar
Paul committed
504
505
506
507
508
509
510
511
    EXPECT(bool{r.result == sum});
    EXPECT(migraphx::contains(r.instructions, "x"));
    EXPECT(bool{r.instructions["x"] == sum});
    EXPECT(not migraphx::contains(r.instructions, "y"));
}

// The first alternative (literal, literal) succeeds, so "y" is never bound.
// at() is used rather than operator[] so a missing key cannot be silently inserted.
TEST_CASE(match_any_of_lazy2)
{
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
    auto m = match::name("sum")(
        match::any_of(match::args(match::name("@literal"), match::name("@literal")).bind("x"),
                      match::args(match::any(), match::any()).bind("y")));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum});
    EXPECT(migraphx::contains(r.instructions, "x"));
    EXPECT(bool{r.instructions.at("x") == sum});
    EXPECT(not migraphx::contains(r.instructions, "y"));
}

// Both alternatives would match; only the first ("x") is bound.
// at() is used rather than operator[] so a missing key cannot be silently inserted.
TEST_CASE(match_any_of_lazy3)
{
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
    auto m = match::name("sum")(
        match::any_of(match::args(match::any(), match::any()).bind("x"),
                      match::args(match::name("@literal"), match::name("@literal")).bind("y")));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum});
    EXPECT(migraphx::contains(r.instructions, "x"));
    EXPECT(bool{r.instructions.at("x") == sum});
    EXPECT(not migraphx::contains(r.instructions, "y"));
}

// Bindings inside the first successful alternative (x1/y1) are recorded; those of
// the untried second alternative (x2/y2) are not.
TEST_CASE(match_any_of_lazy4)
{
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
    auto m = match::name("sum")(match::any_of(
        match::args(match::name("@literal").bind("x1"), match::name("@literal").bind("y1")),
        match::args(match::any().bind("x2"), match::any().bind("y2"))));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum});
    EXPECT(migraphx::contains(r.instructions, "x1"));
    EXPECT(migraphx::contains(r.instructions, "y1"));
    EXPECT(bool{r.instructions.at("x1") == one});
    EXPECT(bool{r.instructions.at("y1") == two});
    EXPECT(not migraphx::contains(r.instructions, "x2"));
    EXPECT(not migraphx::contains(r.instructions, "y2"));
}

// As in lazy4 with the alternatives swapped: the any() alternative succeeds first,
// so only x1/y1 are bound.
TEST_CASE(match_any_of_lazy5)
{
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
    auto m = match::name("sum")(match::any_of(
        match::args(match::any().bind("x1"), match::any().bind("y1")),
        match::args(match::name("@literal").bind("x2"), match::name("@literal").bind("y2"))));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum});
    EXPECT(migraphx::contains(r.instructions, "x1"));
    EXPECT(migraphx::contains(r.instructions, "y1"));
    EXPECT(bool{r.instructions.at("x1") == one});
    EXPECT(bool{r.instructions.at("y1") == two});
    EXPECT(not migraphx::contains(r.instructions, "x2"));
    EXPECT(not migraphx::contains(r.instructions, "y2"));
}

Paul's avatar
Paul committed
584
TEST_CASE(match_none_of1)
Paul's avatar
Paul committed
585
{
586
587
588
589
590
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
591
592
    auto m = match::name("sum")(
        match::none_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("sum"))));
593
    auto r = find_match(mm, m);
Paul's avatar
Paul committed
594
595
596
    EXPECT(bool{r.result == sum});
}

Paul's avatar
Paul committed
597
TEST_CASE(match_none_of2)
Paul's avatar
Paul committed
598
{
599
600
601
602
603
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
604
    auto m = match::name("sum")(match::none_of(match::arg(0)(match::name("@literal")),
Paul's avatar
Paul committed
605
                                               match::arg(1)(match::name("@literal"))));
606
607
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
Paul's avatar
Paul committed
608
609
}

Paul's avatar
Paul committed
610
611
TEST_CASE(match_output1)
{
612
613
614
615
616
617
    migraphx::module mm;
    auto one   = mm.add_literal(1);
    auto two   = mm.add_literal(2);
    auto minus = mm.add_instruction(minus_op{}, two, one);
    auto sum   = mm.add_instruction(sum_op{}, minus, two);
    mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
618
    auto m = match::name("minus")(match::output(match::name("sum")));
619
    auto r = find_match(mm, m);
Paul's avatar
Paul committed
620
621
622
623
624
    EXPECT(bool{r.result == minus});
}

// No literal matches output(name("sum")): one feeds only minus, and two feeds both
// minus and sum.
TEST_CASE(match_output2)
{
    migraphx::module mm;
    auto one   = mm.add_literal(1);
    auto two   = mm.add_literal(2);
    auto minus = mm.add_instruction(minus_op{}, two, one);
    auto sum   = mm.add_instruction(sum_op{}, minus, two);
    mm.add_instruction(pass_op{}, sum);
    auto m = match::name("@literal")(match::output(match::name("sum")));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
}

// skip_output(pass) matches when the consumer is reached directly: minus -> sum.
TEST_CASE(match_skip_output1)
{
    migraphx::module mm;
    auto one   = mm.add_literal(1);
    auto two   = mm.add_literal(2);
    auto minus = mm.add_instruction(minus_op{}, two, one);
    auto sum   = mm.add_instruction(sum_op{}, minus, two);
    mm.add_instruction(pass_op{}, sum);
    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == minus});
}

// skip_output(pass) looks through one pass: minus -> pass -> sum.
TEST_CASE(match_skip_output2)
{
    migraphx::module mm;
    auto one        = mm.add_literal(1);
    auto two        = mm.add_literal(2);
    auto minus      = mm.add_instruction(minus_op{}, two, one);
    auto minus_pass = mm.add_instruction(pass_op{}, minus);
    auto sum        = mm.add_instruction(sum_op{}, minus_pass, two);
    mm.add_instruction(pass_op{}, sum);
    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == minus});
}

// skip_output(pass) looks through a chain of three passes before reaching sum.
TEST_CASE(match_skip_output3)
{
    migraphx::module mm;
    auto one         = mm.add_literal(1);
    auto two         = mm.add_literal(2);
    auto minus       = mm.add_instruction(minus_op{}, two, one);
    auto minus_pass1 = mm.add_instruction(pass_op{}, minus);
    auto minus_pass2 = mm.add_instruction(pass_op{}, minus_pass1);
    auto minus_pass3 = mm.add_instruction(pass_op{}, minus_pass2);
    auto sum         = mm.add_instruction(sum_op{}, minus_pass3, two);
    mm.add_instruction(pass_op{}, sum);
    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == minus});
}

// Literal variant of skip_output(pass)(name("sum")); the expected match is two,
// whose only consumer is sum.
TEST_CASE(match_skip_output4)
{
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto pass = mm.add_instruction(pass_op{}, one);
    auto sum  = mm.add_instruction(sum_op{}, pass, two);
    mm.add_instruction(pass_op{}, sum);
    auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == two});
}

// No literal matches: one feeds both pass and sum2, and two feeds both sum1 and sum3.
TEST_CASE(match_skip_output5)
{
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto pass = mm.add_instruction(pass_op{}, one);
    auto sum1 = mm.add_instruction(sum_op{}, pass, two);
    auto sum2 = mm.add_instruction(sum_op{}, sum1, one);
    auto sum3 = mm.add_instruction(sum_op{}, sum2, two);
    mm.add_instruction(pass_op{}, sum3);
    auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
}

// minus matches: its only consumer, sum1, is a sum (no pass in between).
TEST_CASE(match_skip_output6)
{
    migraphx::module mm;
    auto one   = mm.add_literal(1);
    auto two   = mm.add_literal(2);
    auto minus = mm.add_instruction(minus_op{}, two, one);
    auto sum1  = mm.add_instruction(sum_op{}, minus, two);
    auto sum2  = mm.add_instruction(sum_op{}, sum1, one);
    auto sum3  = mm.add_instruction(sum_op{}, sum2, two);
    mm.add_instruction(pass_op{}, sum3);
    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == minus});
}

TEST_CASE(match_skip_output7)
{
    // minus1 is consumed by minus2, while minus2 is consumed by a sum.
    // Only minus1 is expected to satisfy "minus whose output is a minus".
    migraphx::module mm;
    auto one    = mm.add_literal(1);
    auto two    = mm.add_literal(2);
    auto minus1 = mm.add_instruction(minus_op{}, two, one);
    auto minus2 = mm.add_instruction(minus_op{}, two, minus1);
    auto sum    = mm.add_instruction(sum_op{}, one, minus2);
    mm.add_instruction(pass_op{}, sum);
    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("minus")));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == minus1});
}

Paul's avatar
Paul committed
736
TEST_CASE(match_bind1)
Paul's avatar
Paul committed
737
{
738
739
740
741
742
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum  = mm.add_instruction(sum_op{}, one, two);
    auto pass = mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
743
    auto m    = match::name("pass")(
Paul's avatar
Paul committed
744
745
746
                 match::args(match::name("sum")(match::args(match::name("@literal").bind("one"),
                                                            match::name("@literal").bind("two")))
                                 .bind("sum")),
Paul's avatar
Paul committed
747
                 match::standard_shape())
Paul's avatar
Paul committed
748
                 .bind("pass");
749
    auto r = find_match(mm, m);
Paul's avatar
Paul committed
750
751
752
753
754
755
756
    EXPECT(bool{r.instructions.at("one") == one});
    EXPECT(bool{r.instructions.at("two") == two});
    EXPECT(bool{r.instructions.at("sum") == sum});
    EXPECT(bool{r.instructions.at("pass") == pass});
    EXPECT(bool{r.result == pass});
}

757
TEST_CASE(match_bind_modules1)
Paul Fultz II's avatar
Paul Fultz II committed
758
759
{
    migraphx::program p;
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
    auto* mm    = p.get_main_module();
    auto one    = mm->add_literal(1);
    auto* child = p.create_module("child");
    auto two    = child->add_literal(2);
    auto sum    = child->add_instruction(sum_op{}, one, two);
    child->add_instruction(pass_op{}, sum);
    mm->add_instruction(mod_pass_op{}, {one}, {child});
    auto m = match::name("pass")(
                 match::args(match::name("sum")(match::args(match::name("@literal").bind("one"),
                                                            match::name("@literal").bind("two")))
                                 .bind("sum")),
                 match::standard_shape())
                 .bind("pass");
    auto r = find_match(*child, m);
    EXPECT(not migraphx::contains(r.instructions, "one"));
    EXPECT(not migraphx::contains(r.instructions, "two"));
    EXPECT(not migraphx::contains(r.instructions, "sum"));
    EXPECT(not migraphx::contains(r.instructions, "pass"));
    EXPECT(bool{r.result == child->end()});
}

TEST_CASE(match_bind_modules2)
{
    // The child-module sum again takes its first operand from the main module, but that
    // operand is left unbound here. The whole pattern is expected to match, and the
    // child-local names ("two", "sum", "pass") should resolve to their instructions.
    migraphx::program prog;
    auto* main_mod  = prog.get_main_module();
    auto lit_one    = main_mod->add_literal(1);
    auto* child_mod = prog.create_module("child");
    auto lit_two    = child_mod->add_literal(2);
    auto add_ins    = child_mod->add_instruction(sum_op{}, lit_one, lit_two);
    auto pass_ins   = child_mod->add_instruction(pass_op{}, add_ins);
    main_mod->add_instruction(mod_pass_op{}, {lit_one}, {child_mod});

    // Build the pattern bottom-up: first the bound sum, then the outer pass.
    auto sum_pattern =
        match::name("sum")(
            match::args(match::name("@literal"), match::name("@literal").bind("two")))
            .bind("sum");
    auto pattern =
        match::name("pass")(match::args(sum_pattern), match::standard_shape()).bind("pass");

    auto found = find_match(*child_mod, pattern);
    EXPECT(bool{found.instructions.at("two") == lit_two});
    EXPECT(bool{found.instructions.at("sum") == add_ins});
    EXPECT(bool{found.instructions.at("pass") == pass_ins});
    EXPECT(bool{found.result == pass_ins});
}
803

804
805
806
807
808
809
810
811
TEST_CASE(match_has_value1)
{
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum1 = mm.add_instruction(sum_op{}, one, two);
    auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
    mm.add_instruction(pass_op{}, sum2);
Paul Fultz II's avatar
Paul Fultz II committed
812
    auto m = match::has_value(1);
813
    auto r = find_match(mm, m);
Paul Fultz II's avatar
Paul Fultz II committed
814
815
816
817
818
    EXPECT(bool{r.result == one});
}

TEST_CASE(match_has_value2)
{
    // has_value(2) is expected to select the literal holding 2.
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum1 = mm.add_instruction(sum_op{}, one, two);
    auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::has_value(2);
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == two});
}

TEST_CASE(match_has_value3)
{
    // sum1 has operands (1, 2), so it is the expected match for args(has_value(1), has_value(2)).
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum1 = mm.add_instruction(sum_op{}, one, two);
    auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::name("sum")(match::args(match::has_value(1), match::has_value(2)));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum1});
}

TEST_CASE(match_has_value4)
{
    // No literal holds 3, so no match is expected.
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum1 = mm.add_instruction(sum_op{}, one, two);
    auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::has_value(3);
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
}

TEST_CASE(match_has_value5)
{
    // No sum has operands (1, 3), so no match is expected.
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum1 = mm.add_instruction(sum_op{}, one, two);
    auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::name("sum")(match::args(match::has_value(1), match::has_value(3)));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
}

TEST_CASE(match_has_value6)
{
    // Argument order matters: sum1 is (1, 2), so args(has_value(2), has_value(1))
    // is expected not to match.
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum1 = mm.add_instruction(sum_op{}, one, two);
    auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::name("sum")(match::args(match::has_value(2), match::has_value(1)));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
}

TEST_CASE(match_tree1)
{
    // sum2 = sum(sum(1, 2), 3): leaves in order 1, 2, 3 are expected to match at sum2.
    migraphx::module mm;
    auto one   = mm.add_literal(1);
    auto two   = mm.add_literal(2);
    auto three = mm.add_literal(3);
    auto sum1  = mm.add_instruction(sum_op{}, one, two);
    auto sum2  = mm.add_instruction(sum_op{}, sum1, three);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::tree(
        match::name("sum"), match::has_value(1), match::has_value(2), match::has_value(3));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum2});
}

TEST_CASE(match_tree2)
{
    // Same graph as a (1, 2, 3) tree, but leaves requested as (2, 1, 3): no match is expected.
    migraphx::module mm;
    auto one   = mm.add_literal(1);
    auto two   = mm.add_literal(2);
    auto three = mm.add_literal(3);
    auto sum1  = mm.add_instruction(sum_op{}, one, two);
    auto sum2  = mm.add_instruction(sum_op{}, sum1, three);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::tree(
        match::name("sum"), match::has_value(2), match::has_value(1), match::has_value(3));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
}

TEST_CASE(match_tree3)
{
    // sum2 = sum(3, sum(1, 2)): leaves in order 3, 1, 2 are expected to match at sum2.
    migraphx::module mm;
    auto one   = mm.add_literal(1);
    auto two   = mm.add_literal(2);
    auto three = mm.add_literal(3);
    auto sum1  = mm.add_instruction(sum_op{}, one, two);
    auto sum2  = mm.add_instruction(sum_op{}, three, sum1);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::tree(
        match::name("sum"), match::has_value(3), match::has_value(1), match::has_value(2));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum2});
}

TEST_CASE(match_tree4)
{
    // The pattern asks for an extra leaf (4) that the graph does not contain: no match expected.
    migraphx::module mm;
    auto one   = mm.add_literal(1);
    auto two   = mm.add_literal(2);
    auto three = mm.add_literal(3);
    auto sum1  = mm.add_instruction(sum_op{}, one, two);
    auto sum2  = mm.add_instruction(sum_op{}, sum1, three);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::tree(match::name("sum"),
                         match::has_value(1),
                         match::has_value(2),
                         match::has_value(3),
                         match::has_value(4));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
}

TEST_CASE(match_tree5)
{
    // Leaves (2, 3) do not describe either sum in the graph: no match is expected.
    migraphx::module mm;
    auto one   = mm.add_literal(1);
    auto two   = mm.add_literal(2);
    auto three = mm.add_literal(3);
    auto sum1  = mm.add_instruction(sum_op{}, one, two);
    auto sum2  = mm.add_instruction(sum_op{}, sum1, three);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::tree(match::name("sum"), match::has_value(2), match::has_value(3));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
}

TEST_CASE(match_tree6)
{
    // Leaves (1, 3) skip the middle leaf 2: no match is expected.
    migraphx::module mm;
    auto one   = mm.add_literal(1);
    auto two   = mm.add_literal(2);
    auto three = mm.add_literal(3);
    auto sum1  = mm.add_instruction(sum_op{}, one, two);
    auto sum2  = mm.add_instruction(sum_op{}, sum1, three);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::tree(match::name("sum"), match::has_value(1), match::has_value(3));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
}

TEST_CASE(match_unordered_tree1)
{
    // sum2 = sum(sum(1, 2), 3). Leaves listed as (3, 2, 1) are still expected to match at sum2.
    migraphx::module mm;
    auto one   = mm.add_literal(1);
    auto two   = mm.add_literal(2);
    auto three = mm.add_literal(3);
    auto sum1  = mm.add_instruction(sum_op{}, one, two);
    auto sum2  = mm.add_instruction(sum_op{}, sum1, three);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::unordered_tree(
        match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum2});
}

TEST_CASE(match_unordered_tree2)
{
    // sum2 = sum(3, sum(1, 2)). Leaves listed as (3, 2, 1) are expected to match at sum2.
    migraphx::module mm;
    auto one   = mm.add_literal(1);
    auto two   = mm.add_literal(2);
    auto three = mm.add_literal(3);
    auto sum1  = mm.add_instruction(sum_op{}, one, two);
    auto sum2  = mm.add_instruction(sum_op{}, three, sum1);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::unordered_tree(
        match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum2});
}

TEST_CASE(match_unordered_tree3)
{
    // sum2 = sum(sum(2, 1), 3). Leaves listed as (3, 2, 1) are expected to match at sum2.
    migraphx::module mm;
    auto one   = mm.add_literal(1);
    auto two   = mm.add_literal(2);
    auto three = mm.add_literal(3);
    auto sum1  = mm.add_instruction(sum_op{}, two, one);
    auto sum2  = mm.add_instruction(sum_op{}, sum1, three);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::unordered_tree(
        match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum2});
}

TEST_CASE(match_unordered_tree4)
{
    // The pattern requires a leaf holding 4, which the graph lacks: no match is expected.
    migraphx::module mm;
    auto one   = mm.add_literal(1);
    auto two   = mm.add_literal(2);
    auto three = mm.add_literal(3);
    auto sum1  = mm.add_instruction(sum_op{}, one, two);
    auto sum2  = mm.add_instruction(sum_op{}, sum1, three);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::unordered_tree(
        match::name("sum"), match::has_value(4), match::has_value(2), match::has_value(1));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
}

Paul's avatar
Paul committed
1033
1034
struct match_find_sum
{
Paul's avatar
Paul committed
1035
    migraphx::instruction_ref ins;
Paul's avatar
Paul committed
1036
    auto matcher() const { return match::name("sum"); }
Paul's avatar
Paul committed
1037

1038
    void apply(migraphx::module&, const match::matcher_result& r) const
Paul's avatar
Paul committed
1039
1040
1041
    {
        EXPECT(bool{r.result == ins});
    }
Paul's avatar
Paul committed
1042
1043
1044
1045
};

struct match_find_literal
{
Paul's avatar
Paul committed
1046
    migraphx::instruction_ref ins;
Paul's avatar
Paul committed
1047
    auto matcher() const { return match::name("@literal"); }
Paul's avatar
Paul committed
1048

1049
    void apply(migraphx::module&, const match::matcher_result& r) const
Paul's avatar
Paul committed
1050
1051
1052
1053
1054
1055
    {
        EXPECT(bool{r.result != ins});
        EXPECT(r.result->name() == "@literal");
    }
};

Paul's avatar
Paul committed
1056
TEST_CASE(match_finder)
Paul's avatar
Paul committed
1057
{
1058
1059
1060
1061
1062
1063
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
    match::find_matches(mm, match_find_sum{sum}, match_find_literal{sum});
Paul's avatar
Paul committed
1064
1065
}

Paul's avatar
Paul committed
1066
int main(int argc, const char* argv[]) { test::run(argc, argv); }