matcher.hpp 6.92 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
#ifndef MIGRAPH_GUARD_RTGLIB_MATCHER_HPP
#define MIGRAPH_GUARD_RTGLIB_MATCHER_HPP

#include <migraph/functional.hpp>
#include <migraph/ranges.hpp>
#include <migraph/instruction.hpp>
#include <migraph/program.hpp>
#include <migraph/type_name.hpp>
#include <array>
#include <cassert>
#include <string>
#include <unordered_map>
#include <utility>

namespace migraph {

struct matcher_context
{
Paul's avatar
Paul committed
15
    matcher_context(instruction_ref i) : last(i) {}
Paul's avatar
Paul committed
16
    std::unordered_map<std::string, instruction_ref> instructions;
Paul's avatar
Paul committed
17
18
    instruction_ref not_found() const { return last; }

Paul's avatar
Paul committed
19
    private:
Paul's avatar
Paul committed
20
    instruction_ref last;
Paul's avatar
Paul committed
21
22
};

Paul's avatar
Paul committed
23
template <class P>
Paul's avatar
Paul committed
24
25
26
27
28
29
30
31
32
33
34
35
36
struct predicate_matcher
{
    P p;

    instruction_ref match(matcher_context& ctx, instruction_ref ins) const
    {
        assert(ins != ctx.not_found());
        if(p(ins))
            return ins;
        return ctx.not_found();
    }
};

Paul's avatar
Paul committed
37
template <class F>
Paul's avatar
Paul committed
38
39
40
41
42
43
44
45
46
47
48
struct function_matcher
{
    F f;

    instruction_ref match(matcher_context& ctx, instruction_ref ins) const
    {
        assert(ins != ctx.not_found());
        return f(ctx, ins);
    }
};

Paul's avatar
Paul committed
49
template <class F>
Paul's avatar
Paul committed
50
51
52
53
54
function_matcher<F> make_function_matcher(F f)
{
    return {f};
}

Paul's avatar
Paul committed
55
template <class M>
Paul's avatar
Paul committed
56
57
58
59
60
61
62
63
64
65
auto bind_match(M m, std::string name)
{
    return make_function_matcher([=](matcher_context& ctx, instruction_ref ins) {
        auto result = m.match(ctx, ins);
        if(result != ctx.not_found())
            ctx.instructions.emplace(name, ins);
        return result;
    });
}

Paul's avatar
Paul committed
66
template <class M>
Paul's avatar
Paul committed
67
68
69
70
struct bindable_matcher
{
    M m;

Paul's avatar
Paul committed
71
    auto bind(std::string name) { return bind_match(m, name); }
Paul's avatar
Paul committed
72
73
74
75
76
77
78

    instruction_ref match(matcher_context& ctx, instruction_ref ins) const
    {
        return m.match(ctx, ins);
    }
};

Paul's avatar
Paul committed
79
template <class M>
Paul's avatar
Paul committed
80
81
82
83
84
bindable_matcher<M> make_bindable_matcher(M m)
{
    return {m};
}

Paul's avatar
Paul committed
85
template <class F>
Paul's avatar
Paul committed
86
87
88
89
90
bindable_matcher<function_matcher<F>> make_bf_matcher(F f)
{
    return {{f}};
}

Paul's avatar
Paul committed
91
template <class F>
Paul's avatar
Paul committed
92
93
94
95
96
97
98
99
100
bindable_matcher<predicate_matcher<F>> make_bp_matcher(F f)
{
    return {{f}};
}

using bool_list = std::initializer_list<bool>;

struct id_matcher
{
Paul's avatar
Paul committed
101
    instruction_ref match(matcher_context&, instruction_ref ins) const { return ins; }
Paul's avatar
Paul committed
102
103
};

Paul's avatar
Paul committed
104
template <class M>
Paul's avatar
Paul committed
105
106
107
108
struct basic_matcher
{
    M m;

Paul's avatar
Paul committed
109
    template <class... Ts>
Paul's avatar
Paul committed
110
111
112
113
114
115
    auto operator()(Ts... ms) const
    {
        // Copy m because we cant capture `this` by value
        auto mm = m;
        return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
            auto result = mm.match(ctx, ins);
Paul's avatar
Paul committed
116
            if(result != ctx.not_found())
Paul's avatar
Paul committed
117
118
119
120
121
122
123
124
125
126
127
            {
                bool matches = fold([&](auto x, auto y) {
                    return x and y.match(ctx, result) != ctx.not_found();
                })(true, ms...);
                if(matches)
                    return result;
            }
            return ctx.not_found();
        });
    }

Paul's avatar
Paul committed
128
    auto bind(std::string name) { return bind_match(m, name); }
Paul's avatar
Paul committed
129
130
131
132
133
134
135

    instruction_ref match(matcher_context& ctx, instruction_ref ins) const
    {
        return m.match(ctx, ins);
    }
};

Paul's avatar
Paul committed
136
template <class M>
Paul's avatar
Paul committed
137
138
139
140
141
basic_matcher<M> make_basic_matcher(M m)
{
    return {m};
}

Paul's avatar
Paul committed
142
template <class F>
Paul's avatar
Paul committed
143
144
145
146
147
basic_matcher<function_matcher<F>> make_basic_fun_matcher(F f)
{
    return {{f}};
}

Paul's avatar
Paul committed
148
template <class P>
Paul's avatar
Paul committed
149
150
151
152
153
basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
{
    return {{p}};
}

Paul's avatar
Paul committed
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
// Declares a matcher constant `name` (a migraph::basic_matcher wrapping a
// generated `name##_m` struct) and opens the out-of-line definition of
// `name##_m::match(__VA_ARGS__)`; the function body follows the macro.
// Library names are qualified so the macro also expands correctly outside
// namespace migraph.
#define MIGRAPH_BASIC_MATCHER(name, ...)                                   \
    struct name##_m                                                        \
    {                                                                      \
        migraph::instruction_ref match(__VA_ARGS__) const;                 \
    };                                                                     \
    const constexpr auto name = migraph::basic_matcher<name##_m>{{}};      \
    inline migraph::instruction_ref name##_m::match(__VA_ARGS__) const

// Declares a predicate matcher constant `name` (a migraph::basic_matcher over
// a migraph::predicate_matcher of a generated `name##_m` struct) and opens the
// out-of-line definition of `bool name##_m::operator()(__VA_ARGS__)`; the
// function body follows the macro. Names are qualified for the same reason as
// above (previously `predicate_matcher` was unqualified).
#define MIGRAPH_PRED_MATCHER(name, ...)                                    \
    struct name##_m                                                        \
    {                                                                      \
        bool operator()(__VA_ARGS__) const;                                \
    };                                                                     \
    const constexpr auto name =                                            \
        migraph::basic_matcher<migraph::predicate_matcher<name##_m>>{{}};  \
    inline bool name##_m::operator()(__VA_ARGS__) const
Paul's avatar
Paul committed
169
170
171
172
173
174
175

/// Outcome of match_instruction.
struct matcher_result
{
    // Instructions recorded by bound matchers, keyed by binding name.
    std::unordered_map<std::string, instruction_ref> instructions;
    // Value returned by the top-level matcher; equals the context's
    // not_found() sentinel (the program's end()) when matching failed.
    instruction_ref result;
};

Paul's avatar
Paul committed
176
template <class M>
Paul's avatar
Paul committed
177
178
179
180
181
182
matcher_result match_instruction(program& p, instruction_ref ins, M&& m)
{
    assert(ins != p.end());
    matcher_result result;
    matcher_context ctx{p.end()};
    result.result = m.match(ctx, ins);
Paul's avatar
Paul committed
183
    result.instructions = ctx.instructions;
Paul's avatar
Paul committed
184
    return result;
Paul's avatar
Paul committed
185
186
}

Paul's avatar
Paul committed
187
188
/// Builds a std::array whose element type is the type of the first argument
/// and whose size is the number of arguments.
template <class T, class... Ts>
std::array<T, sizeof...(Ts) + 1> make_array(T x, Ts... xs)
{
    using array_type = std::array<T, sizeof...(Ts) + 1>;
    array_type elements{{x, xs...}};
    return elements;
}

Paul's avatar
Paul committed
193
/// Returns true when every argument converts (via static_cast<bool>) to true.
///
/// "Eager" because the arguments are ordinary function arguments, so all of
/// them have been evaluated before the check runs, unlike a short-circuiting
/// `a and b and c`. An empty argument list yields true; the previous
/// make_array-based version failed to compile for an empty pack.
template <class... Ts>
bool all_of_eager(Ts... xs)
{
    // The leading `true` keeps the array non-empty so an empty pack is
    // well-formed; it cannot change the result.
    const bool values[] = {true, static_cast<bool>(xs)...};
    for(bool v : values)
    {
        if(not v)
            return false;
    }
    return true;
}

namespace matchers {

Paul's avatar
Paul committed
201
template <class... Ts>
Paul's avatar
Paul committed
202
203
204
205
206
207
208
209
210
211
212
213
auto all_of(Ts... ms)
{
    return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
        bool matches = fold([&](auto x, auto y) {
            return x and y.match(ctx, ins) != ctx.not_found();
        })(true, ms...);
        if(matches)
            return ins;
        return ctx.not_found();
    });
}

Paul's avatar
Paul committed
214
template <class... Ts>
Paul's avatar
Paul committed
215
216
217
218
219
220
221
222
223
224
225
226
auto none_of(Ts... ms)
{
    return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
        bool matches = fold([&](auto x, auto y) {
            return x and y.match(ctx, ins) == ctx.not_found();
        })(true, ms...);
        if(matches)
            return ins;
        return ctx.not_found();
    });
}

Paul's avatar
Paul committed
227
template <class... Ts>
Paul's avatar
Paul committed
228
229
230
auto any_of(Ts... ms)
{
    return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
Paul's avatar
Paul committed
231
232
        bool matches = fold(
            [&](auto x, auto y) { return x or y.match(ctx, ins) != ctx.not_found(); })(true, ms...);
Paul's avatar
Paul committed
233
234
235
236
237
238
        if(matches)
            return ins;
        return ctx.not_found();
    });
}

Paul's avatar
Paul committed
239
// Matches instructions whose output shape reports standard().
MIGRAPH_PRED_MATCHER(standard_shape, instruction_ref ins)
{
    return ins->get_shape().standard();
}
Paul's avatar
Paul committed
240
241
242

inline auto name(std::string name)
{
Paul's avatar
Paul committed
243
    return make_basic_pred_matcher([=](instruction_ref ins) { return ins->name() == name; });
Paul's avatar
Paul committed
244
245
246
247
248
249
250
251
252
253
254
255
}

/// Matcher that selects input `i` of the candidate instruction.
///
/// Returns ins->inputs()[i], or ctx.not_found() when the instruction has
/// `i` or fewer inputs. Call it with sub-matchers, e.g. `arg(0)(m)`, to
/// test that input.
inline auto arg(std::size_t i)
{
    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
        const auto& inputs = ins->inputs();
        return i < inputs.size() ? inputs[i] : ctx.not_found();
    });
}

// Workaround for bugs in clang
// Compile-time list of input indices. args() builds one so that args_impl
// can deduce the index pack alongside the matcher pack.
template <std::size_t... Is>
struct args_impl_ints
{
};
Paul's avatar
Paul committed
260

Paul's avatar
Paul committed
261
template <std::size_t... Ns, class... Ms>
Paul's avatar
Paul committed
262
263
264
265
266
auto args_impl(args_impl_ints<Ns...>, Ms... ms)
{
    return matchers::all_of(arg(Ns)(ms)...);
}

Paul's avatar
Paul committed
267
template <class... Ms>
Paul's avatar
Paul committed
268
269
270
271
272
auto args(Ms... ms)
{
    return sequence_c<sizeof...(Ms)>([=](auto... is) {
        // It needs to be written as `decltype(is)::value` for gcc 5
        return args_impl(args_impl_ints<decltype(is)::value...>{}, ms...);
Paul's avatar
Paul committed
273
    });
Paul's avatar
Paul committed
274
275
276
277
278
279
280
}

} // namespace matchers

} // namespace migraph

#endif