// matcher.cpp - unit tests for the migraphx::match combinators.
#include <migraphx/matcher.hpp>
#include <migraphx/iterator_for.hpp>

#include <test.hpp>
#include <basic_ops.hpp>

// Short alias used by every test below.
namespace match = migraphx::match;

template <class M>
Paul's avatar
Paul committed
9
migraphx::match::matcher_result find_match(migraphx::program& p, M&& m)
Paul's avatar
Paul committed
10
{
Paul's avatar
Paul committed
11
12
    migraphx::match::matcher_result result;
    for(auto ins : migraphx::iterator_for(p))
Paul's avatar
Paul committed
13
    {
Paul's avatar
Paul committed
14
        result = migraphx::match::match_instruction(p, ins, m);
Paul's avatar
Paul committed
15
16
17
18
19
20
21
22
        if(result.result != p.end())
            return result;
    }
    return result;
}

void match1()
{
Paul's avatar
Paul committed
23
    migraphx::program p;
Paul's avatar
Paul committed
24
    auto l = p.add_literal(1);
Paul's avatar
Paul committed
25
    auto m = match::standard_shape();
Paul's avatar
Paul committed
26
27
28
29
    auto r = find_match(p, m);
    EXPECT(bool{r.result == l});
}

Paul's avatar
Paul committed
30
TEST_CASE(match_name1)
Paul's avatar
Paul committed
31
{
Paul's avatar
Paul committed
32
    migraphx::program p;
Paul's avatar
Paul committed
33
34
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
35
36
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
37
    auto m = match::name("sum");
Paul's avatar
Paul committed
38
39
40
41
    auto r = find_match(p, m);
    EXPECT(bool{r.result == sum});
}

Paul's avatar
Paul committed
42
TEST_CASE(match_name2)
Paul's avatar
Paul committed
43
{
Paul's avatar
Paul committed
44
    migraphx::program p;
Paul's avatar
Paul committed
45
46
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
47
48
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
49
    auto m = match::name("min");
Paul's avatar
Paul committed
50
51
52
53
    auto r = find_match(p, m);
    EXPECT(bool{r.result == p.end()});
}

Paul's avatar
Paul committed
54
TEST_CASE(match_name3)
Paul's avatar
Paul committed
55
{
Paul's avatar
Paul committed
56
    migraphx::program p;
Paul's avatar
Paul committed
57
58
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
59
60
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
61
    auto m = match::name("sum")(match::standard_shape());
Paul's avatar
Paul committed
62
63
64
65
    auto r = find_match(p, m);
    EXPECT(bool{r.result == sum});
}

Paul's avatar
Paul committed
66
TEST_CASE(match_arg1)
Paul's avatar
Paul committed
67
{
Paul's avatar
Paul committed
68
    migraphx::program p;
Paul's avatar
Paul committed
69
70
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
71
72
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
73
    auto m = match::name("sum")(match::arg(0)(match::name("@literal")), match::standard_shape());
Paul's avatar
Paul committed
74
75
76
77
    auto r = find_match(p, m);
    EXPECT(bool{r.result == sum});
}

Paul's avatar
Paul committed
78
TEST_CASE(match_arg2)
Paul's avatar
Paul committed
79
{
Paul's avatar
Paul committed
80
    migraphx::program p;
Paul's avatar
Paul committed
81
82
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
83
84
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
85
    auto m = match::name("sum")(match::arg(0)(match::name("sum")), match::standard_shape());
Paul's avatar
Paul committed
86
87
88
89
    auto r = find_match(p, m);
    EXPECT(bool{r.result == p.end()});
}

Paul's avatar
Paul committed
90
TEST_CASE(match_arg3)
Paul's avatar
Paul committed
91
{
Paul's avatar
Paul committed
92
    migraphx::program p;
Paul's avatar
Paul committed
93
94
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
95
96
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
97
    auto m = match::name("sum")(match::arg(1)(match::name("@literal")), match::standard_shape());
Paul's avatar
Paul committed
98
99
100
101
    auto r = find_match(p, m);
    EXPECT(bool{r.result == sum});
}

Paul's avatar
Paul committed
102
TEST_CASE(match_arg4)
Paul's avatar
Paul committed
103
{
Paul's avatar
Paul committed
104
    migraphx::program p;
Paul's avatar
Paul committed
105
106
    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
Paul's avatar
Paul committed
107
    auto sum  = p.add_instruction(sum_op{}, one, two);
Paul's avatar
Paul committed
108
    auto pass = p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
109
110
    auto m    = match::name("pass")(match::arg(0)(match::name("sum")), match::standard_shape());
    auto r    = find_match(p, m);
Paul's avatar
Paul committed
111
112
113
    EXPECT(bool{r.result == pass});
}

Paul's avatar
Paul committed
114
TEST_CASE(match_arg5)
Paul's avatar
Paul committed
115
{
Paul's avatar
Paul committed
116
    migraphx::program p;
Paul's avatar
Paul committed
117
118
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
119
120
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
121
    auto m = match::name("pass")(match::arg(1)(match::name("sum")), match::standard_shape());
Paul's avatar
Paul committed
122
123
124
125
    auto r = find_match(p, m);
    EXPECT(bool{r.result == p.end()});
}

Paul's avatar
Paul committed
126
TEST_CASE(match_arg6)
Paul's avatar
Paul committed
127
{
Paul's avatar
Paul committed
128
    migraphx::program p;
Paul's avatar
Paul committed
129
130
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
131
132
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
133
    auto m = match::name("sum")(match::arg(0)(match::name("@literal")));
Paul's avatar
Paul committed
134
135
136
137
    auto r = find_match(p, m);
    EXPECT(bool{r.result == sum});
}

Paul's avatar
Paul committed
138
TEST_CASE(match_arg7)
Paul's avatar
Paul committed
139
{
Paul's avatar
Paul committed
140
    migraphx::program p;
Paul's avatar
Paul committed
141
142
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
143
144
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
145
    auto m = match::name("sum")(match::arg(0)(match::name("@literal")),
Paul's avatar
Paul committed
146
                                match::arg(1)(match::name("@literal")));
Paul's avatar
Paul committed
147
148
149
150
    auto r = find_match(p, m);
    EXPECT(bool{r.result == sum});
}

Paul's avatar
Paul committed
151
152
153
154
155
156
157
158
TEST_CASE(match_arg8)
{
    migraphx::program p;
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
    auto m = match::name("sum")(match::all_of(match::arg(0)(match::name("@literal")),
Paul's avatar
Paul committed
159
                                              match::arg(1)(match::name("@literal"))),
Paul's avatar
Paul committed
160
161
162
163
164
                                match::standard_shape());
    auto r = find_match(p, m);
    EXPECT(bool{r.result == sum});
}

Paul's avatar
Paul committed
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
TEST_CASE(match_nargs1)
{
    migraphx::program p;
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
    auto m = match::name("sum")(match::nargs(2));
    auto r = find_match(p, m);
    EXPECT(bool{r.result == sum});
}

// nargs(2) combined with standard_shape(): the two-input sum is expected to match.
TEST_CASE(match_nargs2)
{
    migraphx::program prog;
    auto lhs = prog.add_literal(1);
    auto rhs = prog.add_literal(2);
    auto add = prog.add_instruction(sum_op{}, lhs, rhs);
    prog.add_instruction(pass_op{}, add);
    auto pattern = match::name("sum")(match::nargs(2), match::standard_shape());
    auto found   = find_match(prog, pattern);
    EXPECT(bool{found.result == add});
}

// all_of with a single nargs(2) operand: the two-input sum is expected to match.
TEST_CASE(match_nargs3)
{
    migraphx::program prog;
    auto lhs = prog.add_literal(1);
    auto rhs = prog.add_literal(2);
    auto add = prog.add_instruction(sum_op{}, lhs, rhs);
    prog.add_instruction(pass_op{}, add);
    auto pattern = match::name("sum")(match::all_of(match::nargs(2)));
    auto found   = find_match(prog, pattern);
    EXPECT(bool{found.result == add});
}

Paul's avatar
Paul committed
202
TEST_CASE(match_args1)
Paul's avatar
Paul committed
203
{
Paul's avatar
Paul committed
204
    migraphx::program p;
Paul's avatar
Paul committed
205
206
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
207
208
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
209
210
    auto m = match::name("sum")(match::args(match::name("@literal"), match::name("@literal")),
                                match::standard_shape());
Paul's avatar
Paul committed
211
212
213
214
    auto r = find_match(p, m);
    EXPECT(bool{r.result == sum});
}

Paul's avatar
Paul committed
215
TEST_CASE(match_args2)
Paul's avatar
Paul committed
216
{
Paul's avatar
Paul committed
217
    migraphx::program p;
Paul's avatar
Paul committed
218
219
220
221
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
222
223
    auto m = match::name("sum")(match::args(match::name("@literal"), match::name("sum")),
                                match::standard_shape());
Paul's avatar
Paul committed
224
225
226
227
    auto r = find_match(p, m);
    EXPECT(bool{r.result == p.end()});
}

Paul's avatar
Paul committed
228
TEST_CASE(match_args3)
Paul's avatar
Paul committed
229
{
Paul's avatar
Paul committed
230
    migraphx::program p;
Paul's avatar
Paul committed
231
232
233
234
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
235
    auto m = match::name("sum")(match::args(match::name("@literal")), match::standard_shape());
Paul's avatar
Paul committed
236
    auto r = find_match(p, m);
237
    EXPECT(bool{r.result == p.end()});
Paul's avatar
Paul committed
238
239
}

Paul's avatar
Paul committed
240
TEST_CASE(match_args4)
Paul's avatar
Paul committed
241
{
Paul's avatar
Paul committed
242
    migraphx::program p;
Paul's avatar
Paul committed
243
244
    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
Paul's avatar
Paul committed
245
246
247
    auto sum1 = p.add_instruction(sum_op{}, one, two);
    auto sum2 = p.add_instruction(sum_op{}, sum1, two);
    p.add_instruction(pass_op{}, sum2);
Paul's avatar
Paul committed
248
249
    auto m = match::name("sum")(match::args(match::name("sum"), match::name("@literal")),
                                match::standard_shape());
Paul's avatar
Paul committed
250
251
252
253
    auto r = find_match(p, m);
    EXPECT(bool{r.result == sum2});
}

Paul's avatar
Paul committed
254
TEST_CASE(match_args5)
Paul's avatar
Paul committed
255
{
Paul's avatar
Paul committed
256
    migraphx::program p;
Paul's avatar
Paul committed
257
258
259
260
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
261
262
    auto m = match::name("sum")(match::args(match::name("sum"), match::name("@literal")),
                                match::standard_shape());
Paul's avatar
Paul committed
263
264
265
266
    auto r = find_match(p, m);
    EXPECT(bool{r.result == p.end()});
}

Paul's avatar
Paul committed
267
TEST_CASE(match_args6)
Paul's avatar
Paul committed
268
{
Paul's avatar
Paul committed
269
    migraphx::program p;
Paul's avatar
Paul committed
270
271
272
    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
    auto sum  = p.add_instruction(sum_op{}, one, two);
Paul's avatar
Paul committed
273
    auto pass = p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
274
275
    auto m    = match::name("pass")(match::args(match::name("sum")), match::standard_shape());
    auto r    = find_match(p, m);
Paul's avatar
Paul committed
276
277
278
    EXPECT(bool{r.result == pass});
}

Paul's avatar
Paul committed
279
TEST_CASE(match_args7)
Paul's avatar
Paul committed
280
{
Paul's avatar
Paul committed
281
    migraphx::program p;
Paul's avatar
Paul committed
282
283
284
    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
    auto sum  = p.add_instruction(sum_op{}, one, two);
Paul's avatar
Paul committed
285
    auto pass = p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
286
    auto m    = match::name("pass")(match::args(match::name("sum")(match::args(
Paul's avatar
Paul committed
287
288
                                     match::name("@literal"), match::name("@literal")))),
                                 match::standard_shape());
Paul's avatar
Paul committed
289
    auto r    = find_match(p, m);
Paul's avatar
Paul committed
290
291
292
    EXPECT(bool{r.result == pass});
}

Paul's avatar
Paul committed
293
TEST_CASE(match_either_args1)
Paul's avatar
Paul committed
294
{
Paul's avatar
Paul committed
295
    migraphx::program p;
Paul's avatar
Paul committed
296
297
    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
Paul's avatar
Paul committed
298
    auto sum1 = p.add_instruction(sum_op{}, one, two);
Paul's avatar
Paul committed
299
300
    auto sum2 = p.add_instruction(sum_op{}, sum1, two);
    p.add_instruction(pass_op{}, sum2);
Paul's avatar
Paul committed
301
302
303
    auto m =
        match::name("sum")(match::either_arg(0, 1)(match::name("sum"), match::name("@literal")));
    auto r = find_match(p, m);
Paul's avatar
Paul committed
304
305
306
    EXPECT(bool{r.result == sum2});
}

Paul's avatar
Paul committed
307
TEST_CASE(match_either_args2)
Paul's avatar
Paul committed
308
{
Paul's avatar
Paul committed
309
    migraphx::program p;
Paul's avatar
Paul committed
310
311
    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
Paul's avatar
Paul committed
312
    auto sum1 = p.add_instruction(sum_op{}, one, two);
Paul's avatar
Paul committed
313
314
    auto sum2 = p.add_instruction(sum_op{}, sum1, two);
    p.add_instruction(pass_op{}, sum2);
Paul's avatar
Paul committed
315
316
317
    auto m =
        match::name("sum")(match::either_arg(0, 1)(match::name("@literal"), match::name("sum")));
    auto r = find_match(p, m);
Paul's avatar
Paul committed
318
319
320
    EXPECT(bool{r.result == sum2});
}

Paul's avatar
Paul committed
321
TEST_CASE(match_either_args3)
Paul's avatar
Paul committed
322
{
Paul's avatar
Paul committed
323
    migraphx::program p;
Paul's avatar
Paul committed
324
325
    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
Paul's avatar
Paul committed
326
    auto sum1 = p.add_instruction(sum_op{}, one, two);
Paul's avatar
Paul committed
327
328
    auto sum2 = p.add_instruction(sum_op{}, sum1, two);
    p.add_instruction(pass_op{}, sum2);
Paul's avatar
Paul committed
329
330
331
    auto m =
        match::name("sum")(match::either_arg(0, 1)(match::name("pass"), match::name("@literal")));
    auto r = find_match(p, m);
Paul's avatar
Paul committed
332
333
334
    EXPECT(bool{r.result == p.end()});
}

Paul's avatar
Paul committed
335
TEST_CASE(match_all_of1)
Paul's avatar
Paul committed
336
{
Paul's avatar
Paul committed
337
    migraphx::program p;
Paul's avatar
Paul committed
338
339
340
341
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
342
    auto m = match::name("sum")(match::all_of(match::arg(0)(match::name("@literal")),
Paul's avatar
Paul committed
343
                                              match::arg(1)(match::name("@literal"))));
Paul's avatar
Paul committed
344
345
346
347
    auto r = find_match(p, m);
    EXPECT(bool{r.result == sum});
}

Paul's avatar
Paul committed
348
TEST_CASE(match_all_of2)
Paul's avatar
Paul committed
349
{
Paul's avatar
Paul committed
350
    migraphx::program p;
Paul's avatar
Paul committed
351
352
353
354
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
355
356
    auto m = match::name("sum")(
        match::all_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("@literal"))));
Paul's avatar
Paul committed
357
358
359
360
    auto r = find_match(p, m);
    EXPECT(bool{r.result == p.end()});
}

Paul's avatar
Paul committed
361
362
363
364
365
366
367
368
369
370
371
372
373
TEST_CASE(match_all_of3)
{
    migraphx::program p;
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
    auto m = match::name("sum")(match::all_of(match::all_of(match::arg(0)(match::name("@literal")),
                                              match::arg(1)(match::name("@literal")))));
    auto r = find_match(p, m);
    EXPECT(bool{r.result == sum});
}

Paul's avatar
Paul committed
374
TEST_CASE(match_any_of1)
Paul's avatar
Paul committed
375
{
Paul's avatar
Paul committed
376
    migraphx::program p;
Paul's avatar
Paul committed
377
378
379
380
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
381
382
    auto m = match::name("sum")(
        match::any_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("@literal"))));
Paul's avatar
Paul committed
383
384
385
386
    auto r = find_match(p, m);
    EXPECT(bool{r.result == sum});
}

Paul's avatar
Paul committed
387
TEST_CASE(match_any_of2)
Paul's avatar
Paul committed
388
{
Paul's avatar
Paul committed
389
    migraphx::program p;
Paul's avatar
Paul committed
390
391
392
393
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
394
395
    auto m = match::name("sum")(
        match::any_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("sum"))));
Paul's avatar
Paul committed
396
397
398
399
    auto r = find_match(p, m);
    EXPECT(bool{r.result == p.end()});
}

Paul's avatar
Paul committed
400
TEST_CASE(match_none_of1)
Paul's avatar
Paul committed
401
{
Paul's avatar
Paul committed
402
    migraphx::program p;
Paul's avatar
Paul committed
403
404
405
406
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
407
408
    auto m = match::name("sum")(
        match::none_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("sum"))));
Paul's avatar
Paul committed
409
410
411
412
    auto r = find_match(p, m);
    EXPECT(bool{r.result == sum});
}

Paul's avatar
Paul committed
413
TEST_CASE(match_none_of2)
Paul's avatar
Paul committed
414
{
Paul's avatar
Paul committed
415
    migraphx::program p;
Paul's avatar
Paul committed
416
417
418
419
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
420
    auto m = match::name("sum")(match::none_of(match::arg(0)(match::name("@literal")),
Paul's avatar
Paul committed
421
                                               match::arg(1)(match::name("@literal"))));
Paul's avatar
Paul committed
422
423
424
425
    auto r = find_match(p, m);
    EXPECT(bool{r.result == p.end()});
}

Paul's avatar
Paul committed
426
427
428
TEST_CASE(match_output1)
{
    migraphx::program p;
Paul's avatar
Paul committed
429
430
    auto one   = p.add_literal(1);
    auto two   = p.add_literal(2);
Paul's avatar
Paul committed
431
    auto minus = p.add_instruction(minus_op{}, two, one);
Paul's avatar
Paul committed
432
    auto sum   = p.add_instruction(sum_op{}, minus, two);
Paul's avatar
Paul committed
433
434
435
436
437
438
439
440
441
    p.add_instruction(pass_op{}, sum);
    auto m = match::name("minus")(match::output(match::name("sum")));
    auto r = find_match(p, m);
    EXPECT(bool{r.result == minus});
}

// No literal is expected to satisfy output(name("sum")), even though `two`
// feeds `sum` (it also feeds `minus`).
TEST_CASE(match_output2)
{
    migraphx::program p;
    auto one   = p.add_literal(1);
    auto two   = p.add_literal(2);
    auto minus = p.add_instruction(minus_op{}, two, one);
    auto sum   = p.add_instruction(sum_op{}, minus, two);
    p.add_instruction(pass_op{}, sum);
    auto m = match::name("@literal")(match::output(match::name("sum")));
    auto r = find_match(p, m);
    EXPECT(bool{r.result == p.end()});
}

// skip_output(pass)(sum) with no pass in between: minus feeds sum directly.
TEST_CASE(match_skip_output1)
{
    migraphx::program p;
    auto one   = p.add_literal(1);
    auto two   = p.add_literal(2);
    auto minus = p.add_instruction(minus_op{}, two, one);
    auto sum   = p.add_instruction(sum_op{}, minus, two);
    p.add_instruction(pass_op{}, sum);
    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
    auto r = find_match(p, m);
    EXPECT(bool{r.result == minus});
}

// One pass between minus and sum is skipped, so minus should match.
TEST_CASE(match_skip_output2)
{
    migraphx::program p;
    auto one        = p.add_literal(1);
    auto two        = p.add_literal(2);
    auto minus      = p.add_instruction(minus_op{}, two, one);
    auto minus_pass = p.add_instruction(pass_op{}, minus);
    auto sum        = p.add_instruction(sum_op{}, minus_pass, two);
    p.add_instruction(pass_op{}, sum);
    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
    auto r = find_match(p, m);
    EXPECT(bool{r.result == minus});
}

// A chain of three passes between minus and sum is skipped, so minus should match.
TEST_CASE(match_skip_output3)
{
    migraphx::program p;
    auto one         = p.add_literal(1);
    auto two         = p.add_literal(2);
    auto minus       = p.add_instruction(minus_op{}, two, one);
    auto minus_pass1 = p.add_instruction(pass_op{}, minus);
    auto minus_pass2 = p.add_instruction(pass_op{}, minus_pass1);
    auto minus_pass3 = p.add_instruction(pass_op{}, minus_pass2);
    auto sum         = p.add_instruction(sum_op{}, minus_pass3, two);
    p.add_instruction(pass_op{}, sum);
    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
    auto r = find_match(p, m);
    EXPECT(bool{r.result == minus});
}

// Both literals reach `sum` (`one` through a pass, `two` directly).
// The test expects `two` to be reported, presumably because find_match visits
// it first; literal placement is decided by program::add_literal (TODO confirm).
TEST_CASE(match_skip_output4)
{
    migraphx::program p;
    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
    auto pass = p.add_instruction(pass_op{}, one);
    auto sum  = p.add_instruction(sum_op{}, pass, two);
    p.add_instruction(pass_op{}, sum);
    auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
    auto r = find_match(p, m);
    EXPECT(bool{r.result == two});
}

// Each literal feeds more than one instruction (`two` feeds sum1 and sum3;
// `one` feeds the pass and sum2), and no match is expected.
TEST_CASE(match_skip_output5)
{
    migraphx::program p;
    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
    auto pass = p.add_instruction(pass_op{}, one);
    auto sum1 = p.add_instruction(sum_op{}, pass, two);
    auto sum2 = p.add_instruction(sum_op{}, sum1, one);
    auto sum3 = p.add_instruction(sum_op{}, sum2, two);
    p.add_instruction(pass_op{}, sum3);
    auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
    auto r = find_match(p, m);
    EXPECT(bool{r.result == p.end()});
}

// minus feeds only sum1, so it matches even with a longer chain of sums downstream.
TEST_CASE(match_skip_output6)
{
    migraphx::program p;
    auto one   = p.add_literal(1);
    auto two   = p.add_literal(2);
    auto minus = p.add_instruction(minus_op{}, two, one);
    auto sum1  = p.add_instruction(sum_op{}, minus, two);
    auto sum2  = p.add_instruction(sum_op{}, sum1, one);
    auto sum3  = p.add_instruction(sum_op{}, sum2, two);
    p.add_instruction(pass_op{}, sum3);
    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
    auto r = find_match(p, m);
    EXPECT(bool{r.result == minus});
}

// minus1 feeds minus2, so skip_output(pass)(minus) should match minus1.
TEST_CASE(match_skip_output7)
{
    migraphx::program p;
    auto one    = p.add_literal(1);
    auto two    = p.add_literal(2);
    auto minus1 = p.add_instruction(minus_op{}, two, one);
    auto minus2 = p.add_instruction(minus_op{}, two, minus1);
    auto sum    = p.add_instruction(sum_op{}, one, minus2);
    p.add_instruction(pass_op{}, sum);
    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("minus")));
    auto r = find_match(p, m);
    EXPECT(bool{r.result == minus1});
}

Paul's avatar
Paul committed
552
TEST_CASE(match_bind1)
Paul's avatar
Paul committed
553
{
Paul's avatar
Paul committed
554
    migraphx::program p;
Paul's avatar
Paul committed
555
556
557
    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
    auto sum  = p.add_instruction(sum_op{}, one, two);
Paul's avatar
Paul committed
558
    auto pass = p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
559
    auto m    = match::name("pass")(
Paul's avatar
Paul committed
560
561
562
                 match::args(match::name("sum")(match::args(match::name("@literal").bind("one"),
                                                            match::name("@literal").bind("two")))
                                 .bind("sum")),
Paul's avatar
Paul committed
563
                 match::standard_shape())
Paul's avatar
Paul committed
564
                 .bind("pass");
Paul's avatar
Paul committed
565
566
567
568
569
570
571
572
    auto r = find_match(p, m);
    EXPECT(bool{r.instructions.at("one") == one});
    EXPECT(bool{r.instructions.at("two") == two});
    EXPECT(bool{r.instructions.at("sum") == sum});
    EXPECT(bool{r.instructions.at("pass") == pass});
    EXPECT(bool{r.result == pass});
}

Paul's avatar
Paul committed
573
574
struct match_find_sum
{
Paul's avatar
Paul committed
575
    migraphx::instruction_ref ins;
Paul's avatar
Paul committed
576
    auto matcher() const { return match::name("sum"); }
Paul's avatar
Paul committed
577

Paul's avatar
Paul committed
578
579
580
581
    void apply(migraphx::program&, const match::matcher_result& r) const
    {
        EXPECT(bool{r.result == ins});
    }
Paul's avatar
Paul committed
582
583
584
585
};

struct match_find_literal
{
Paul's avatar
Paul committed
586
    migraphx::instruction_ref ins;
Paul's avatar
Paul committed
587
    auto matcher() const { return match::name("@literal"); }
Paul's avatar
Paul committed
588

Paul's avatar
Paul committed
589
    void apply(migraphx::program&, const match::matcher_result& r) const
Paul's avatar
Paul committed
590
591
592
593
594
595
    {
        EXPECT(bool{r.result != ins});
        EXPECT(r.result->name() == "@literal");
    }
};

Paul's avatar
Paul committed
596
TEST_CASE(match_finder)
Paul's avatar
Paul committed
597
{
Paul's avatar
Paul committed
598
    migraphx::program p;
Paul's avatar
Paul committed
599
600
601
602
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
603
    match::find_matches(p, match_find_sum{sum}, match_find_literal{sum});
Paul's avatar
Paul committed
604
605
}

Paul's avatar
Paul committed
606
int main(int argc, const char* argv[]) { test::run(argc, argv); }