// matcher.cpp: unit tests for the migraphx::match instruction matchers.
#include <migraphx/matcher.hpp>
#include <migraphx/iterator_for.hpp>
Paul's avatar
Paul committed
3
4
5
#include <test.hpp>
#include <basic_ops.hpp>

Paul's avatar
Paul committed
6
namespace match = migraphx::match;
Paul's avatar
Paul committed
7

Paul's avatar
Paul committed
8
9
// Predicate matcher that unconditionally throws. The lazy-evaluation tests pair
// it with a matcher that already decides the outcome; if the combinator evaluated
// this operand anyway, the exception would propagate out of the test.
MIGRAPHX_PRED_MATCHER(throws, migraphx::instruction_ref)
{
    MIGRAPHX_THROW("Matcher throws");
}

Paul's avatar
Paul committed
10
template <class M>
11
migraphx::match::matcher_result find_match(migraphx::module& modl, M&& m)
Paul's avatar
Paul committed
12
{
Paul's avatar
Paul committed
13
    migraphx::match::matcher_result result;
14
    for(auto ins : migraphx::iterator_for(modl))
Paul's avatar
Paul committed
15
    {
16
17
        result = migraphx::match::match_instruction(modl, ins, m);
        if(result.result != modl.end())
Paul's avatar
Paul committed
18
19
20
21
22
23
24
            return result;
    }
    return result;
}

void match1()
{
25
26
27
28
    migraphx::module mm;
    auto l = mm.add_literal(1);
    auto m = match::standard_shape();
    auto r = find_match(mm, m);
Paul's avatar
Paul committed
29
30
31
    EXPECT(bool{r.result == l});
}

Paul's avatar
Paul committed
32
// A bare name matcher should locate the sum instruction.
TEST_CASE(match_name1)
{
    migraphx::module mod;
    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto res = find_match(mod, match::name("sum"));
    EXPECT(bool{res.result == sum_ins});
}

Paul's avatar
Paul committed
44
// Requesting an op name absent from the module yields no match.
TEST_CASE(match_name2)
{
    migraphx::module mod;
    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto res = find_match(mod, match::name("min"));
    EXPECT(bool{res.result == mod.end()});
}

Paul's avatar
Paul committed
56
// A name matcher refined with standard_shape() still finds the sum.
TEST_CASE(match_name3)
{
    migraphx::module mod;
    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto res = find_match(mod, match::name("sum")(match::standard_shape()));
    EXPECT(bool{res.result == sum_ins});
}

Paul's avatar
Paul committed
68
// arg(0) being a literal is satisfied by the sum.
TEST_CASE(match_arg1)
{
    migraphx::module mod;
    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto pattern =
        match::name("sum")(match::arg(0)(match::name("@literal")), match::standard_shape());
    auto res = find_match(mod, pattern);
    EXPECT(bool{res.result == sum_ins});
}

Paul's avatar
Paul committed
80
// The sum's first input is a literal, not a sum, so nothing matches.
TEST_CASE(match_arg2)
{
    migraphx::module mod;
    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto pattern =
        match::name("sum")(match::arg(0)(match::name("sum")), match::standard_shape());
    auto res = find_match(mod, pattern);
    EXPECT(bool{res.result == mod.end()});
}

Paul's avatar
Paul committed
92
// arg(1) being a literal is satisfied by the sum.
TEST_CASE(match_arg3)
{
    migraphx::module mod;
    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto pattern =
        match::name("sum")(match::arg(1)(match::name("@literal")), match::standard_shape());
    auto res = find_match(mod, pattern);
    EXPECT(bool{res.result == sum_ins});
}

Paul's avatar
Paul committed
104
// The pass instruction's first input is the sum, so the pass is matched.
TEST_CASE(match_arg4)
{
    migraphx::module mod;
    auto lit1     = mod.add_literal(1);
    auto lit2     = mod.add_literal(2);
    auto sum_ins  = mod.add_instruction(sum_op{}, lit1, lit2);
    auto pass_ins = mod.add_instruction(pass_op{}, sum_ins);

    auto pattern =
        match::name("pass")(match::arg(0)(match::name("sum")), match::standard_shape());
    auto res = find_match(mod, pattern);
    EXPECT(bool{res.result == pass_ins});
}

Paul's avatar
Paul committed
116
// Asking for arg(1) of the single-input pass instruction finds nothing.
TEST_CASE(match_arg5)
{
    migraphx::module mod;
    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto pattern =
        match::name("pass")(match::arg(1)(match::name("sum")), match::standard_shape());
    auto res = find_match(mod, pattern);
    EXPECT(bool{res.result == mod.end()});
}

Paul's avatar
Paul committed
128
// arg(0) alone (no shape constraint) matches the sum.
TEST_CASE(match_arg6)
{
    migraphx::module mod;
    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto res = find_match(mod, match::name("sum")(match::arg(0)(match::name("@literal"))));
    EXPECT(bool{res.result == sum_ins});
}

Paul's avatar
Paul committed
140
// Constraining both inputs to literals still matches the sum.
TEST_CASE(match_arg7)
{
    migraphx::module mod;
    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto first_is_lit  = match::arg(0)(match::name("@literal"));
    auto second_is_lit = match::arg(1)(match::name("@literal"));
    auto res           = find_match(mod, match::name("sum")(first_is_lit, second_is_lit));
    EXPECT(bool{res.result == sum_ins});
}

Paul's avatar
Paul committed
153
154
// Grouping the two argument constraints in all_of, plus standard_shape(), matches the sum.
TEST_CASE(match_arg8)
{
    migraphx::module mod;
    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto both_lits = match::all_of(match::arg(0)(match::name("@literal")),
                                   match::arg(1)(match::name("@literal")));
    auto res       = find_match(mod, match::name("sum")(both_lits, match::standard_shape()));
    EXPECT(bool{res.result == sum_ins});
}

Paul's avatar
Paul committed
167
168
// The sum has exactly two inputs.
TEST_CASE(match_nargs1)
{
    migraphx::module mod;
    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto res = find_match(mod, match::name("sum")(match::nargs(2)));
    EXPECT(bool{res.result == sum_ins});
}

// nargs(2) combined with standard_shape() matches the sum.
TEST_CASE(match_nargs2)
{
    migraphx::module mod;
    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto res = find_match(mod, match::name("sum")(match::nargs(2), match::standard_shape()));
    EXPECT(bool{res.result == sum_ins});
}

// nargs(2) wrapped in a single-element all_of still matches the sum.
TEST_CASE(match_nargs3)
{
    migraphx::module mod;
    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto res = find_match(mod, match::name("sum")(match::all_of(match::nargs(2))));
    EXPECT(bool{res.result == sum_ins});
}

Paul's avatar
Paul committed
203
// args(literal, literal) describes the sum's inputs exactly.
TEST_CASE(match_args1)
{
    migraphx::module mod;
    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto inputs = match::args(match::name("@literal"), match::name("@literal"));
    auto res    = find_match(mod, match::name("sum")(inputs, match::standard_shape()));
    EXPECT(bool{res.result == sum_ins});
}

Paul's avatar
Paul committed
216
// The sum's second input is a literal, so args(literal, sum) does not match.
TEST_CASE(match_args2)
{
    migraphx::module mod;
    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto inputs = match::args(match::name("@literal"), match::name("sum"));
    auto res    = find_match(mod, match::name("sum")(inputs, match::standard_shape()));
    EXPECT(bool{res.result == mod.end()});
}

Paul's avatar
Paul committed
229
// args() with a single matcher does not match the two-input sum.
TEST_CASE(match_args3)
{
    migraphx::module mod;
    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto pattern =
        match::name("sum")(match::args(match::name("@literal")), match::standard_shape());
    auto res = find_match(mod, pattern);
    EXPECT(bool{res.result == mod.end()});
}

Paul's avatar
Paul committed
241
// Only the outer sum (inputs: inner sum, literal) satisfies args(sum, literal).
TEST_CASE(match_args4)
{
    migraphx::module mod;
    auto lit1  = mod.add_literal(1);
    auto lit2  = mod.add_literal(2);
    auto inner = mod.add_instruction(sum_op{}, lit1, lit2);
    auto outer = mod.add_instruction(sum_op{}, inner, lit2);
    mod.add_instruction(pass_op{}, outer);

    auto inputs = match::args(match::name("sum"), match::name("@literal"));
    auto res    = find_match(mod, match::name("sum")(inputs, match::standard_shape()));
    EXPECT(bool{res.result == outer});
}

Paul's avatar
Paul committed
255
// With only literal-fed sums present, args(sum, literal) finds nothing.
TEST_CASE(match_args5)
{
    migraphx::module mod;
    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto inputs = match::args(match::name("sum"), match::name("@literal"));
    auto res    = find_match(mod, match::name("sum")(inputs, match::standard_shape()));
    EXPECT(bool{res.result == mod.end()});
}

Paul's avatar
Paul committed
268
// The pass instruction's single input is the sum.
TEST_CASE(match_args6)
{
    migraphx::module mod;
    auto lit1     = mod.add_literal(1);
    auto lit2     = mod.add_literal(2);
    auto sum_ins  = mod.add_instruction(sum_op{}, lit1, lit2);
    auto pass_ins = mod.add_instruction(pass_op{}, sum_ins);

    auto pattern = match::name("pass")(match::args(match::name("sum")), match::standard_shape());
    auto res     = find_match(mod, pattern);
    EXPECT(bool{res.result == pass_ins});
}

Paul's avatar
Paul committed
280
// Nested args(): pass <- sum <- (literal, literal) matches the pass.
TEST_CASE(match_args7)
{
    migraphx::module mod;
    auto lit1     = mod.add_literal(1);
    auto lit2     = mod.add_literal(2);
    auto sum_ins  = mod.add_instruction(sum_op{}, lit1, lit2);
    auto pass_ins = mod.add_instruction(pass_op{}, sum_ins);

    auto inner_sum =
        match::name("sum")(match::args(match::name("@literal"), match::name("@literal")));
    auto pattern = match::name("pass")(match::args(inner_sum), match::standard_shape());
    auto res     = find_match(mod, pattern);
    EXPECT(bool{res.result == pass_ins});
}

Paul's avatar
Paul committed
294
// either_arg(0, 1)(sum, literal) matches the outer sum in the given order.
TEST_CASE(match_either_args1)
{
    migraphx::module mod;
    auto lit1  = mod.add_literal(1);
    auto lit2  = mod.add_literal(2);
    auto inner = mod.add_instruction(sum_op{}, lit1, lit2);
    auto outer = mod.add_instruction(sum_op{}, inner, lit2);
    mod.add_instruction(pass_op{}, outer);

    auto either = match::either_arg(0, 1)(match::name("sum"), match::name("@literal"));
    auto res    = find_match(mod, match::name("sum")(either));
    EXPECT(bool{res.result == outer});
}

Paul's avatar
Paul committed
308
// either_arg(0, 1)(literal, sum) also matches the outer sum with the order swapped.
TEST_CASE(match_either_args2)
{
    migraphx::module mod;
    auto lit1  = mod.add_literal(1);
    auto lit2  = mod.add_literal(2);
    auto inner = mod.add_instruction(sum_op{}, lit1, lit2);
    auto outer = mod.add_instruction(sum_op{}, inner, lit2);
    mod.add_instruction(pass_op{}, outer);

    auto either = match::either_arg(0, 1)(match::name("@literal"), match::name("sum"));
    auto res    = find_match(mod, match::name("sum")(either));
    EXPECT(bool{res.result == outer});
}

Paul's avatar
Paul committed
322
// No sum has a pass input, so either_arg(0, 1)(pass, literal) finds nothing.
TEST_CASE(match_either_args3)
{
    migraphx::module mod;
    auto lit1  = mod.add_literal(1);
    auto lit2  = mod.add_literal(2);
    auto inner = mod.add_instruction(sum_op{}, lit1, lit2);
    auto outer = mod.add_instruction(sum_op{}, inner, lit2);
    mod.add_instruction(pass_op{}, outer);

    auto either = match::either_arg(0, 1)(match::name("pass"), match::name("@literal"));
    auto res    = find_match(mod, match::name("sum")(either));
    EXPECT(bool{res.result == mod.end()});
}

Paul's avatar
Paul committed
336
337
// With any() on both sides, the first sum matches and "x"/"y" bind distinct inputs.
TEST_CASE(match_either_args_any1)
{
    migraphx::module mod;
    auto lit1  = mod.add_literal(1);
    auto lit2  = mod.add_literal(2);
    auto inner = mod.add_instruction(sum_op{}, lit1, lit2);
    auto outer = mod.add_instruction(sum_op{}, inner, lit2);
    mod.add_instruction(pass_op{}, outer);

    auto either = match::either_arg(0, 1)(match::any().bind("x"), match::any().bind("y"));
    auto res    = find_match(mod, match::name("sum")(either));
    EXPECT(bool{res.result == inner});
    EXPECT(bool{res.instructions.at("x") != res.instructions.at("y")});
}

// any() paired with a literal binds two distinct inputs of the first sum.
TEST_CASE(match_either_args_any2)
{
    migraphx::module mod;
    auto lit1  = mod.add_literal(1);
    auto lit2  = mod.add_literal(2);
    auto inner = mod.add_instruction(sum_op{}, lit1, lit2);
    auto outer = mod.add_instruction(sum_op{}, inner, lit2);
    mod.add_instruction(pass_op{}, outer);

    auto either =
        match::either_arg(0, 1)(match::any().bind("x"), match::name("@literal").bind("y"));
    auto res = find_match(mod, match::name("sum")(either));
    EXPECT(bool{res.result == inner});
    EXPECT(bool{res.instructions.at("x") != res.instructions.at("y")});
}

// A literal paired with any() binds two distinct inputs of the first sum.
TEST_CASE(match_either_args_any3)
{
    migraphx::module mod;
    auto lit1  = mod.add_literal(1);
    auto lit2  = mod.add_literal(2);
    auto inner = mod.add_instruction(sum_op{}, lit1, lit2);
    auto outer = mod.add_instruction(sum_op{}, inner, lit2);
    mod.add_instruction(pass_op{}, outer);

    auto either =
        match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y"));
    auto res = find_match(mod, match::name("sum")(either));
    EXPECT(bool{res.result == inner});
    EXPECT(bool{res.instructions.at("x") != res.instructions.at("y")});
}

// Requiring a sum input ("x") selects the outer sum; "y" binds its other input.
TEST_CASE(match_either_args_any4)
{
    migraphx::module mod;
    auto lit1  = mod.add_literal(1);
    auto lit2  = mod.add_literal(2);
    auto inner = mod.add_instruction(sum_op{}, lit1, lit2);
    auto outer = mod.add_instruction(sum_op{}, inner, lit2);
    mod.add_instruction(pass_op{}, outer);

    auto either =
        match::either_arg(0, 1)(match::name("sum").bind("x"), match::any().bind("y"));
    auto res = find_match(mod, match::name("sum")(either));
    EXPECT(bool{res.result == outer});
    EXPECT(bool{res.instructions.at("x") != res.instructions.at("y")});
}

// Requiring a sum input ("y") selects the outer sum; "x" binds its other input.
TEST_CASE(match_either_args_any5)
{
    migraphx::module mod;
    auto lit1  = mod.add_literal(1);
    auto lit2  = mod.add_literal(2);
    auto inner = mod.add_instruction(sum_op{}, lit1, lit2);
    auto outer = mod.add_instruction(sum_op{}, inner, lit2);
    mod.add_instruction(pass_op{}, outer);

    auto either =
        match::either_arg(0, 1)(match::any().bind("x"), match::name("sum").bind("y"));
    auto res = find_match(mod, match::name("sum")(either));
    EXPECT(bool{res.result == outer});
    EXPECT(bool{res.instructions.at("x") != res.instructions.at("y")});
}

Paul's avatar
Paul committed
411
// all_of succeeds when both argument constraints hold.
TEST_CASE(match_all_of1)
{
    migraphx::module mod;
    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto both = match::all_of(match::arg(0)(match::name("@literal")),
                              match::arg(1)(match::name("@literal")));
    auto res  = find_match(mod, match::name("sum")(both));
    EXPECT(bool{res.result == sum_ins});
}

Paul's avatar
Paul committed
424
// all_of fails because arg(0) is not a sum.
TEST_CASE(match_all_of2)
{
    migraphx::module mod;
    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto both = match::all_of(match::arg(0)(match::name("sum")),
                              match::arg(1)(match::name("@literal")));
    auto res  = find_match(mod, match::name("sum")(both));
    EXPECT(bool{res.result == mod.end()});
}

Paul's avatar
Paul committed
437
438
// A nested all_of behaves like the flat one.
TEST_CASE(match_all_of3)
{
    migraphx::module mod;
    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto inner_all = match::all_of(match::arg(0)(match::name("@literal")),
                                   match::arg(1)(match::name("@literal")));
    auto res       = find_match(mod, match::name("sum")(match::all_of(inner_all)));
    EXPECT(bool{res.result == sum_ins});
}

Paul's avatar
Paul committed
450
451
// any_of short-circuits: any() already matches, so throws() must not be evaluated.
TEST_CASE(match_lazy_any_of)
{
    migraphx::module mod;
    auto lit = mod.add_literal(1);
    mod.add_instruction(pass_op{}, lit);

    auto res = find_match(mod, match::any_of(match::any(), throws()));
    EXPECT(bool{res.result == lit});
}

// all_of short-circuits: none() already fails, so throws() must not be evaluated.
TEST_CASE(match_lazy_all_of)
{
    migraphx::module mod;
    auto lit = mod.add_literal(1);
    mod.add_instruction(pass_op{}, lit);

    auto res = find_match(mod, match::all_of(match::none(), throws()));
    EXPECT(bool{res.result == mod.end()});
}

// none_of short-circuits: any() already matches, so throws() must not be evaluated.
TEST_CASE(match_lazy_none_of)
{
    migraphx::module mod;
    auto lit = mod.add_literal(1);
    mod.add_instruction(pass_op{}, lit);

    auto res = find_match(mod, match::none_of(match::any(), throws()));
    EXPECT(bool{res.result == mod.end()});
}

Paul's avatar
Paul committed
480
// any_of succeeds because arg(1) is a literal, even though arg(0) is not a sum.
TEST_CASE(match_any_of1)
{
    migraphx::module mod;
    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto either = match::any_of(match::arg(0)(match::name("sum")),
                                match::arg(1)(match::name("@literal")));
    auto res    = find_match(mod, match::name("sum")(either));
    EXPECT(bool{res.result == sum_ins});
}

Paul's avatar
Paul committed
493
// Neither input is a sum, so any_of fails.
TEST_CASE(match_any_of2)
{
    migraphx::module mod;
    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto either =
        match::any_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("sum")));
    auto res = find_match(mod, match::name("sum")(either));
    EXPECT(bool{res.result == mod.end()});
}

Paul's avatar
Paul committed
506
507
// The first any_of alternative matches, so only its "x" binding should be recorded.
// Bindings are read with .at() rather than operator[]. operator[] would
// default-insert a singular instruction_ref for a missing key (comparing it is UB)
// and mutate the result under inspection. This matches the match_either_args_any tests.
TEST_CASE(match_any_of_lazy1)
{
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
    auto m = match::name("sum")(
        match::any_of(match::args(match::any(), match::any()).bind("x"),
                      match::args(match::name("sum"), match::name("sum")).bind("y")));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum});
    EXPECT(migraphx::contains(r.instructions, "x"));
    EXPECT(bool{r.instructions.at("x") == sum});
    EXPECT(not migraphx::contains(r.instructions, "y"));
}

// The first alternative (two literal inputs) matches, so "y" from the second
// alternative must not be bound.
// Bindings are read with .at() so a missing key cannot default-insert a singular
// instruction_ref into the result map.
TEST_CASE(match_any_of_lazy2)
{
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
    auto m = match::name("sum")(
        match::any_of(match::args(match::name("@literal"), match::name("@literal")).bind("x"),
                      match::args(match::any(), match::any()).bind("y")));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum});
    EXPECT(migraphx::contains(r.instructions, "x"));
    EXPECT(bool{r.instructions.at("x") == sum});
    EXPECT(not migraphx::contains(r.instructions, "y"));
}

// Both alternatives could match, but any_of stops at the first, so only "x" is bound.
// Bindings are read with .at() so a missing key cannot default-insert a singular
// instruction_ref into the result map.
TEST_CASE(match_any_of_lazy3)
{
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
    auto m = match::name("sum")(
        match::any_of(match::args(match::any(), match::any()).bind("x"),
                      match::args(match::name("@literal"), match::name("@literal")).bind("y")));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum});
    EXPECT(migraphx::contains(r.instructions, "x"));
    EXPECT(bool{r.instructions.at("x") == sum});
    EXPECT(not migraphx::contains(r.instructions, "y"));
}

// Per-argument bindings from the first matching alternative ("x1"/"y1") are kept.
// Those from the unevaluated second alternative ("x2"/"y2") are absent.
// Bindings are read with .at() so a missing key cannot default-insert a singular
// instruction_ref into the result map.
TEST_CASE(match_any_of_lazy4)
{
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
    auto m = match::name("sum")(match::any_of(
        match::args(match::name("@literal").bind("x1"), match::name("@literal").bind("y1")),
        match::args(match::any().bind("x2"), match::any().bind("y2"))));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum});
    EXPECT(migraphx::contains(r.instructions, "x1"));
    EXPECT(migraphx::contains(r.instructions, "y1"));
    EXPECT(bool{r.instructions.at("x1") == one});
    EXPECT(bool{r.instructions.at("y1") == two});
    EXPECT(not migraphx::contains(r.instructions, "x2"));
    EXPECT(not migraphx::contains(r.instructions, "y2"));
}

// The first alternative (any/any) matches, so only "x1"/"y1" are bound.
// Bindings are read with .at() so a missing key cannot default-insert a singular
// instruction_ref into the result map.
TEST_CASE(match_any_of_lazy5)
{
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
    auto m = match::name("sum")(match::any_of(
        match::args(match::any().bind("x1"), match::any().bind("y1")),
        match::args(match::name("@literal").bind("x2"), match::name("@literal").bind("y2"))));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum});
    EXPECT(migraphx::contains(r.instructions, "x1"));
    EXPECT(migraphx::contains(r.instructions, "y1"));
    EXPECT(bool{r.instructions.at("x1") == one});
    EXPECT(bool{r.instructions.at("y1") == two});
    EXPECT(not migraphx::contains(r.instructions, "x2"));
    EXPECT(not migraphx::contains(r.instructions, "y2"));
}

Paul's avatar
Paul committed
597
// Neither input is a sum, so none_of succeeds.
TEST_CASE(match_none_of1)
{
    migraphx::module mod;
    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto neither =
        match::none_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("sum")));
    auto res = find_match(mod, match::name("sum")(neither));
    EXPECT(bool{res.result == sum_ins});
}

Paul's avatar
Paul committed
610
// Both inputs are literals, so none_of(literal checks) fails.
TEST_CASE(match_none_of2)
{
    migraphx::module mod;
    auto lit1    = mod.add_literal(1);
    auto lit2    = mod.add_literal(2);
    auto sum_ins = mod.add_instruction(sum_op{}, lit1, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto neither = match::none_of(match::arg(0)(match::name("@literal")),
                                  match::arg(1)(match::name("@literal")));
    auto res     = find_match(mod, match::name("sum")(neither));
    EXPECT(bool{res.result == mod.end()});
}

Paul's avatar
Paul committed
623
624
// The minus feeds a sum, so output(name("sum")) matches the minus.
TEST_CASE(match_output1)
{
    migraphx::module mod;
    auto lit1      = mod.add_literal(1);
    auto lit2      = mod.add_literal(2);
    auto minus_ins = mod.add_instruction(minus_op{}, lit2, lit1);
    auto sum_ins   = mod.add_instruction(sum_op{}, minus_ins, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto res = find_match(mod, match::name("minus")(match::output(match::name("sum"))));
    EXPECT(bool{res.result == minus_ins});
}

// In this graph no literal satisfies output(name("sum")).
TEST_CASE(match_output2)
{
    migraphx::module mod;
    auto lit1      = mod.add_literal(1);
    auto lit2      = mod.add_literal(2);
    auto minus_ins = mod.add_instruction(minus_op{}, lit2, lit1);
    auto sum_ins   = mod.add_instruction(sum_op{}, minus_ins, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto res = find_match(mod, match::name("@literal")(match::output(match::name("sum"))));
    EXPECT(bool{res.result == mod.end()});
}

// With no pass in between, skip_output behaves like output: the minus feeds a sum.
TEST_CASE(match_skip_output1)
{
    migraphx::module mod;
    auto lit1      = mod.add_literal(1);
    auto lit2      = mod.add_literal(2);
    auto minus_ins = mod.add_instruction(minus_op{}, lit2, lit1);
    auto sum_ins   = mod.add_instruction(sum_op{}, minus_ins, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto skip_pass = match::skip_output(match::name("pass"));
    auto res       = find_match(mod, match::name("minus")(skip_pass(match::name("sum"))));
    EXPECT(bool{res.result == minus_ins});
}

// One pass sits between the minus and the sum; skip_output looks through it.
TEST_CASE(match_skip_output2)
{
    migraphx::module mod;
    auto lit1      = mod.add_literal(1);
    auto lit2      = mod.add_literal(2);
    auto minus_ins = mod.add_instruction(minus_op{}, lit2, lit1);
    auto via_pass  = mod.add_instruction(pass_op{}, minus_ins);
    auto sum_ins   = mod.add_instruction(sum_op{}, via_pass, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto skip_pass = match::skip_output(match::name("pass"));
    auto res       = find_match(mod, match::name("minus")(skip_pass(match::name("sum"))));
    EXPECT(bool{res.result == minus_ins});
}

// A chain of three passes between the minus and the sum is skipped.
TEST_CASE(match_skip_output3)
{
    migraphx::module mod;
    auto lit1      = mod.add_literal(1);
    auto lit2      = mod.add_literal(2);
    auto minus_ins = mod.add_instruction(minus_op{}, lit2, lit1);
    auto hop1      = mod.add_instruction(pass_op{}, minus_ins);
    auto hop2      = mod.add_instruction(pass_op{}, hop1);
    auto hop3      = mod.add_instruction(pass_op{}, hop2);
    auto sum_ins   = mod.add_instruction(sum_op{}, hop3, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto skip_pass = match::skip_output(match::name("pass"));
    auto res       = find_match(mod, match::name("minus")(skip_pass(match::name("sum"))));
    EXPECT(bool{res.result == minus_ins});
}

// Literal case: the expected match is the literal `2`, which feeds the sum.
TEST_CASE(match_skip_output4)
{
    migraphx::module mod;
    auto lit1     = mod.add_literal(1);
    auto lit2     = mod.add_literal(2);
    auto pass_ins = mod.add_instruction(pass_op{}, lit1);
    auto sum_ins  = mod.add_instruction(sum_op{}, pass_ins, lit2);
    mod.add_instruction(pass_op{}, sum_ins);

    auto skip_pass = match::skip_output(match::name("pass"));
    auto res       = find_match(mod, match::name("@literal")(skip_pass(match::name("sum"))));
    EXPECT(bool{res.result == lit2});
}

// Here each literal is consumed more than once, and no literal matches.
TEST_CASE(match_skip_output5)
{
    migraphx::module mod;
    auto lit1     = mod.add_literal(1);
    auto lit2     = mod.add_literal(2);
    auto pass_ins = mod.add_instruction(pass_op{}, lit1);
    auto acc1     = mod.add_instruction(sum_op{}, pass_ins, lit2);
    auto acc2     = mod.add_instruction(sum_op{}, acc1, lit1);
    auto acc3     = mod.add_instruction(sum_op{}, acc2, lit2);
    mod.add_instruction(pass_op{}, acc3);

    auto skip_pass = match::skip_output(match::name("pass"));
    auto res       = find_match(mod, match::name("@literal")(skip_pass(match::name("sum"))));
    EXPECT(bool{res.result == mod.end()});
}

// There is nothing to skip here: minus feeds sum1 directly. The test expects
// skip_output to still let minus match.
TEST_CASE(match_skip_output6)
{
    migraphx::module mm;
    auto one   = mm.add_literal(1);
    auto two   = mm.add_literal(2);
    auto minus = mm.add_instruction(minus_op{}, two, one);
    auto sum1  = mm.add_instruction(sum_op{}, minus, two);
    auto sum2  = mm.add_instruction(sum_op{}, sum1, one);
    auto sum3  = mm.add_instruction(sum_op{}, sum2, two);
    mm.add_instruction(pass_op{}, sum3);
    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == minus});
}

// The output matcher itself looks for "minus": minus1 feeds minus2, so the
// test expects minus1 (not minus2) to be the matched instruction.
TEST_CASE(match_skip_output7)
{
    migraphx::module mm;
    auto one    = mm.add_literal(1);
    auto two    = mm.add_literal(2);
    auto minus1 = mm.add_instruction(minus_op{}, two, one);
    auto minus2 = mm.add_instruction(minus_op{}, two, minus1);
    auto sum    = mm.add_instruction(sum_op{}, one, minus2);
    mm.add_instruction(pass_op{}, sum);
    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("minus")));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == minus1});
}

Paul's avatar
Paul committed
749
TEST_CASE(match_bind1)
Paul's avatar
Paul committed
750
{
751
752
753
754
755
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum  = mm.add_instruction(sum_op{}, one, two);
    auto pass = mm.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
756
    auto m    = match::name("pass")(
Paul's avatar
Paul committed
757
758
759
                 match::args(match::name("sum")(match::args(match::name("@literal").bind("one"),
                                                            match::name("@literal").bind("two")))
                                 .bind("sum")),
Paul's avatar
Paul committed
760
                 match::standard_shape())
Paul's avatar
Paul committed
761
                 .bind("pass");
762
    auto r = find_match(mm, m);
Paul's avatar
Paul committed
763
764
765
766
767
768
769
    EXPECT(bool{r.instructions.at("one") == one});
    EXPECT(bool{r.instructions.at("two") == two});
    EXPECT(bool{r.instructions.at("sum") == sum});
    EXPECT(bool{r.instructions.at("pass") == pass});
    EXPECT(bool{r.result == pass});
}

770
TEST_CASE(match_bind_modules1)
Paul Fultz II's avatar
Paul Fultz II committed
771
772
{
    migraphx::program p;
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
    auto* mm    = p.get_main_module();
    auto one    = mm->add_literal(1);
    auto* child = p.create_module("child");
    auto two    = child->add_literal(2);
    auto sum    = child->add_instruction(sum_op{}, one, two);
    child->add_instruction(pass_op{}, sum);
    mm->add_instruction(mod_pass_op{}, {one}, {child});
    auto m = match::name("pass")(
                 match::args(match::name("sum")(match::args(match::name("@literal").bind("one"),
                                                            match::name("@literal").bind("two")))
                                 .bind("sum")),
                 match::standard_shape())
                 .bind("pass");
    auto r = find_match(*child, m);
    EXPECT(not migraphx::contains(r.instructions, "one"));
    EXPECT(not migraphx::contains(r.instructions, "two"));
    EXPECT(not migraphx::contains(r.instructions, "sum"));
    EXPECT(not migraphx::contains(r.instructions, "pass"));
    EXPECT(bool{r.result == child->end()});
}

// Same graph as match_bind_modules1, but the literal that comes from the main
// module is matched without a binding. The test expects the match to succeed
// and the remaining bindings (two, sum, pass) to be populated.
TEST_CASE(match_bind_modules2)
{
    migraphx::program p;
    auto* mm    = p.get_main_module();
    auto one    = mm->add_literal(1);
    auto* child = p.create_module("child");
    auto two    = child->add_literal(2);
    auto sum    = child->add_instruction(sum_op{}, one, two);
    auto pass   = child->add_instruction(pass_op{}, sum);
    mm->add_instruction(mod_pass_op{}, {one}, {child});
    auto m = match::name("pass")(
                 match::args(match::name("sum")(match::args(match::name("@literal"),
                                                            match::name("@literal").bind("two")))
                                 .bind("sum")),
                 match::standard_shape())
                 .bind("pass");
    auto r = find_match(*child, m);
    EXPECT(bool{r.instructions.at("two") == two});
    EXPECT(bool{r.instructions.at("sum") == sum});
    EXPECT(bool{r.instructions.at("pass") == pass});
    EXPECT(bool{r.result == pass});
}

817
818
819
820
821
822
823
824
TEST_CASE(match_has_value1)
{
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum1 = mm.add_instruction(sum_op{}, one, two);
    auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
    mm.add_instruction(pass_op{}, sum2);
Paul Fultz II's avatar
Paul Fultz II committed
825
    auto m = match::has_value(1);
826
    auto r = find_match(mm, m);
Paul Fultz II's avatar
Paul Fultz II committed
827
828
829
830
831
    EXPECT(bool{r.result == one});
}

// has_value(2) should locate the literal holding 2.
TEST_CASE(match_has_value2)
{
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum1 = mm.add_instruction(sum_op{}, one, two);
    auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::has_value(2);
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == two});
}

// A sum whose arguments are the literals 1 and 2, in that order, is sum1.
TEST_CASE(match_has_value3)
{
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum1 = mm.add_instruction(sum_op{}, one, two);
    auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::name("sum")(match::args(match::has_value(1), match::has_value(2)));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum1});
}

// No literal holds 3, so has_value(3) must not match anything.
TEST_CASE(match_has_value4)
{
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum1 = mm.add_instruction(sum_op{}, one, two);
    auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::has_value(3);
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
}

// No sum takes the literals (1, 3) as arguments, so nothing should match.
TEST_CASE(match_has_value5)
{
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum1 = mm.add_instruction(sum_op{}, one, two);
    auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::name("sum")(match::args(match::has_value(1), match::has_value(3)));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
}

// args() is order-sensitive: sum1 takes (1, 2), so asking for (2, 1) must not
// match.
TEST_CASE(match_has_value6)
{
    migraphx::module mm;
    auto one  = mm.add_literal(1);
    auto two  = mm.add_literal(2);
    auto sum1 = mm.add_instruction(sum_op{}, one, two);
    auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::name("sum")(match::args(match::has_value(2), match::has_value(1)));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
}

// sum2 = sum(sum(1, 2), 3); tree(sum, 1, 2, 3) is expected to match sum2.
TEST_CASE(match_tree1)
{
    migraphx::module mm;
    auto one   = mm.add_literal(1);
    auto two   = mm.add_literal(2);
    auto three = mm.add_literal(3);
    auto sum1  = mm.add_instruction(sum_op{}, one, two);
    auto sum2  = mm.add_instruction(sum_op{}, sum1, three);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::tree(
        match::name("sum"), match::has_value(1), match::has_value(2), match::has_value(3));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum2});
}

// Same graph as match_tree1, but the leaves are listed as (2, 1, 3); the
// ordered tree matcher is expected to reject it.
TEST_CASE(match_tree2)
{
    migraphx::module mm;
    auto one   = mm.add_literal(1);
    auto two   = mm.add_literal(2);
    auto three = mm.add_literal(3);
    auto sum1  = mm.add_instruction(sum_op{}, one, two);
    auto sum2  = mm.add_instruction(sum_op{}, sum1, three);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::tree(
        match::name("sum"), match::has_value(2), match::has_value(1), match::has_value(3));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
}

// The nested sum is the second argument here: sum2 = sum(3, sum(1, 2)).
// tree(sum, 3, 1, 2) is expected to match sum2.
TEST_CASE(match_tree3)
{
    migraphx::module mm;
    auto one   = mm.add_literal(1);
    auto two   = mm.add_literal(2);
    auto three = mm.add_literal(3);
    auto sum1  = mm.add_instruction(sum_op{}, one, two);
    auto sum2  = mm.add_instruction(sum_op{}, three, sum1);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::tree(
        match::name("sum"), match::has_value(3), match::has_value(1), match::has_value(2));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum2});
}

// The pattern asks for four leaves but the graph only has three, so nothing
// should match.
TEST_CASE(match_tree4)
{
    migraphx::module mm;
    auto one   = mm.add_literal(1);
    auto two   = mm.add_literal(2);
    auto three = mm.add_literal(3);
    auto sum1  = mm.add_instruction(sum_op{}, one, two);
    auto sum2  = mm.add_instruction(sum_op{}, sum1, three);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::tree(match::name("sum"),
                         match::has_value(1),
                         match::has_value(2),
                         match::has_value(3),
                         match::has_value(4));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
}

// tree(sum, 2, 3) does not describe either sum in the graph, so nothing
// should match.
TEST_CASE(match_tree5)
{
    migraphx::module mm;
    auto one   = mm.add_literal(1);
    auto two   = mm.add_literal(2);
    auto three = mm.add_literal(3);
    auto sum1  = mm.add_instruction(sum_op{}, one, two);
    auto sum2  = mm.add_instruction(sum_op{}, sum1, three);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::tree(match::name("sum"), match::has_value(2), match::has_value(3));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
}

// tree(sum, 1, 3) does not describe either sum in the graph, so nothing
// should match.
TEST_CASE(match_tree6)
{
    migraphx::module mm;
    auto one   = mm.add_literal(1);
    auto two   = mm.add_literal(2);
    auto three = mm.add_literal(3);
    auto sum1  = mm.add_instruction(sum_op{}, one, two);
    auto sum2  = mm.add_instruction(sum_op{}, sum1, three);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::tree(match::name("sum"), match::has_value(1), match::has_value(3));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
}

// sum2 = sum(sum(1, 2), 3); the unordered tree with leaves listed as
// (3, 2, 1) is expected to match sum2.
TEST_CASE(match_unordered_tree1)
{
    migraphx::module mm;
    auto one   = mm.add_literal(1);
    auto two   = mm.add_literal(2);
    auto three = mm.add_literal(3);
    auto sum1  = mm.add_instruction(sum_op{}, one, two);
    auto sum2  = mm.add_instruction(sum_op{}, sum1, three);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::unordered_tree(
        match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum2});
}

// Outer argument order is swapped: sum2 = sum(3, sum(1, 2)). The unordered
// tree is still expected to match sum2.
TEST_CASE(match_unordered_tree2)
{
    migraphx::module mm;
    auto one   = mm.add_literal(1);
    auto two   = mm.add_literal(2);
    auto three = mm.add_literal(3);
    auto sum1  = mm.add_instruction(sum_op{}, one, two);
    auto sum2  = mm.add_instruction(sum_op{}, three, sum1);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::unordered_tree(
        match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum2});
}

// Inner argument order is swapped: sum2 = sum(sum(2, 1), 3). The unordered
// tree is still expected to match sum2.
TEST_CASE(match_unordered_tree3)
{
    migraphx::module mm;
    auto one   = mm.add_literal(1);
    auto two   = mm.add_literal(2);
    auto three = mm.add_literal(3);
    auto sum1  = mm.add_instruction(sum_op{}, two, one);
    auto sum2  = mm.add_instruction(sum_op{}, sum1, three);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::unordered_tree(
        match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == sum2});
}

// No literal holds 4, so even the unordered tree must not match.
TEST_CASE(match_unordered_tree4)
{
    migraphx::module mm;
    auto one   = mm.add_literal(1);
    auto two   = mm.add_literal(2);
    auto three = mm.add_literal(3);
    auto sum1  = mm.add_instruction(sum_op{}, one, two);
    auto sum2  = mm.add_instruction(sum_op{}, sum1, three);
    mm.add_instruction(pass_op{}, sum2);
    auto m = match::unordered_tree(
        match::name("sum"), match::has_value(4), match::has_value(2), match::has_value(1));
    auto r = find_match(mm, m);
    EXPECT(bool{r.result == mm.end()});
}

Paul's avatar
Paul committed
1046
1047
struct match_find_sum
{
Paul's avatar
Paul committed
1048
    migraphx::instruction_ref ins;
Paul's avatar
Paul committed
1049
    auto matcher() const { return match::name("sum"); }
Paul's avatar
Paul committed
1050

1051
    void apply(migraphx::module&, const match::matcher_result& r) const
Paul's avatar
Paul committed
1052
1053
1054
    {
        EXPECT(bool{r.result == ins});
    }
Paul's avatar
Paul committed
1055
1056
1057
1058
};

struct match_find_literal
{
Paul's avatar
Paul committed
1059
    migraphx::instruction_ref ins;
Paul's avatar
Paul committed
1060
    auto matcher() const { return match::name("@literal"); }
Paul's avatar
Paul committed
1061

1062
    void apply(migraphx::module&, const match::matcher_result& r) const
Paul's avatar
Paul committed
1063
1064
1065
1066
1067
1068
    {
        EXPECT(bool{r.result != ins});
        EXPECT(r.result->name() == "@literal");
    }
};

Paul's avatar
Paul committed
1069
TEST_CASE(match_finder)
Paul's avatar
Paul committed
1070
{
1071
1072
1073
1074
1075
1076
    migraphx::module mm;
    auto one = mm.add_literal(1);
    auto two = mm.add_literal(2);
    auto sum = mm.add_instruction(sum_op{}, one, two);
    mm.add_instruction(pass_op{}, sum);
    match::find_matches(mm, match_find_sum{sum}, match_find_literal{sum});
Paul's avatar
Paul committed
1077
1078
}

Paul's avatar
Paul committed
1079
int main(int argc, const char* argv[]) { test::run(argc, argv); }