// matcher.cpp -- tests for the migraph instruction matchers.
#include <migraph/matcher.hpp>
#include <migraph/iterator_for.hpp>
#include <test.hpp>
#include <basic_ops.hpp>

namespace matchers = migraph::matchers;

// Search helper used by the test cases in this file.
template <class M>
Paul's avatar
Paul committed
9
10
11
migraph::matcher_result find_match(migraph::program& p, M&& m)
{
    migraph::matcher_result result;
Paul's avatar
Paul committed
12
    for(auto ins : migraph::iterator_for(p))
Paul's avatar
Paul committed
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
    {
        result = migraph::match_instruction(p, ins, m);
        if(result.result != p.end())
            return result;
    }
    return result;
}

// Single-literal program searched with a bare standard_shape() matcher; the
// search is expected to return the literal itself.
void match1()
{
    auto pattern = matchers::standard_shape();

    migraph::program p;
    auto l = p.add_literal(1);

    auto r = find_match(p, pattern);
    EXPECT(bool{r.result == l});
}

void match_name1()
{
    migraph::program p;
Paul's avatar
Paul committed
33
34
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
35
36
37
38
39
40
41
42
43
44
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
    auto m = matchers::name("sum");
    auto r = find_match(p, m);
    EXPECT(bool{r.result == sum});
}

void match_name2()
{
    migraph::program p;
Paul's avatar
Paul committed
45
46
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
47
48
49
50
51
52
53
54
55
56
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
    auto m = matchers::name("min");
    auto r = find_match(p, m);
    EXPECT(bool{r.result == p.end()});
}

void match_name3()
{
    migraph::program p;
Paul's avatar
Paul committed
57
58
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
59
60
61
62
63
64
65
66
67
68
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
    auto m = matchers::name("sum")(matchers::standard_shape());
    auto r = find_match(p, m);
    EXPECT(bool{r.result == sum});
}

void match_arg1()
{
    migraph::program p;
Paul's avatar
Paul committed
69
70
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
71
72
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
73
74
    auto m = matchers::name("sum")(matchers::arg(0)(matchers::name("@literal")),
                                   matchers::standard_shape());
Paul's avatar
Paul committed
75
76
77
78
79
80
81
    auto r = find_match(p, m);
    EXPECT(bool{r.result == sum});
}

void match_arg2()
{
    migraph::program p;
Paul's avatar
Paul committed
82
83
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
84
85
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
86
87
    auto m =
        matchers::name("sum")(matchers::arg(0)(matchers::name("sum")), matchers::standard_shape());
Paul's avatar
Paul committed
88
89
90
91
92
93
94
    auto r = find_match(p, m);
    EXPECT(bool{r.result == p.end()});
}

void match_arg3()
{
    migraph::program p;
Paul's avatar
Paul committed
95
96
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
97
98
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
99
100
    auto m = matchers::name("sum")(matchers::arg(1)(matchers::name("@literal")),
                                   matchers::standard_shape());
Paul's avatar
Paul committed
101
102
103
104
105
106
107
108
109
    auto r = find_match(p, m);
    EXPECT(bool{r.result == sum});
}

void match_arg4()
{
    migraph::program p;
    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
Paul's avatar
Paul committed
110
    auto sum  = p.add_instruction(sum_op{}, one, two);
Paul's avatar
Paul committed
111
    auto pass = p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
112
113
    auto m =
        matchers::name("pass")(matchers::arg(0)(matchers::name("sum")), matchers::standard_shape());
Paul's avatar
Paul committed
114
115
116
117
118
119
120
    auto r = find_match(p, m);
    EXPECT(bool{r.result == pass});
}

void match_arg5()
{
    migraph::program p;
Paul's avatar
Paul committed
121
122
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
123
124
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
125
126
    auto m =
        matchers::name("pass")(matchers::arg(1)(matchers::name("sum")), matchers::standard_shape());
Paul's avatar
Paul committed
127
128
129
130
131
132
133
    auto r = find_match(p, m);
    EXPECT(bool{r.result == p.end()});
}

void match_arg6()
{
    migraph::program p;
Paul's avatar
Paul committed
134
135
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
136
137
138
139
140
141
142
143
144
145
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
    auto m = matchers::name("sum")(matchers::arg(0)(matchers::name("@literal")));
    auto r = find_match(p, m);
    EXPECT(bool{r.result == sum});
}

void match_arg7()
{
    migraph::program p;
Paul's avatar
Paul committed
146
147
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
148
149
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
150
151
    auto m = matchers::name("sum")(matchers::arg(0)(matchers::name("@literal")),
                                   matchers::arg(1)(matchers::name("@literal")));
Paul's avatar
Paul committed
152
153
154
155
156
157
158
    auto r = find_match(p, m);
    EXPECT(bool{r.result == sum});
}

void match_args1()
{
    migraph::program p;
Paul's avatar
Paul committed
159
160
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
161
162
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(pass_op{}, sum);
Paul's avatar
Paul committed
163
164
165
    auto m = matchers::name("sum")(
        matchers::args(matchers::name("@literal"), matchers::name("@literal")),
        matchers::standard_shape());
Paul's avatar
Paul committed
166
167
168
169
    auto r = find_match(p, m);
    EXPECT(bool{r.result == sum});
}

// Test driver.
int main()
{
Paul's avatar
Paul committed
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
    match1();
    match_name1();
    match_name2();
    match_name3();

    match_arg1();
    match_arg2();
    match_arg3();
    match_arg4();
    match_arg5();
    match_arg6();
    match_arg7();

    match_args1();
}