matcher.hpp 9.33 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
#ifndef MIGRAPH_GUARD_RTGLIB_MATCHER_HPP
#define MIGRAPH_GUARD_RTGLIB_MATCHER_HPP

#include <migraph/functional.hpp>
#include <migraph/ranges.hpp>
#include <migraph/instruction.hpp>
#include <migraph/program.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/config.hpp>

#include <cassert>
#include <string>
#include <unordered_map>
#include <utility>

12
13
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
Paul's avatar
Paul committed
14

Paul's avatar
Paul committed
15
namespace match {
Paul's avatar
Paul committed
16

Paul's avatar
Paul committed
17
18
struct matcher_context
{
Paul's avatar
Paul committed
19
    matcher_context(instruction_ref i) : last(i) {}
Paul's avatar
Paul committed
20
    std::unordered_map<std::string, instruction_ref> instructions;
Paul's avatar
Paul committed
21
22
    instruction_ref not_found() const { return last; }

Paul's avatar
Paul committed
23
    private:
Paul's avatar
Paul committed
24
    instruction_ref last;
Paul's avatar
Paul committed
25
26
};

Paul's avatar
Paul committed
27
/// Convert a predicate function into a matcher
Paul's avatar
Paul committed
28
template <class P>
Paul's avatar
Paul committed
29
30
31
32
33
34
35
36
37
38
39
40
41
struct predicate_matcher
{
    P p;

    instruction_ref match(matcher_context& ctx, instruction_ref ins) const
    {
        assert(ins != ctx.not_found());
        if(p(ins))
            return ins;
        return ctx.not_found();
    }
};

Paul's avatar
Paul committed
42
/// Convert a function into a matcher
Paul's avatar
Paul committed
43
template <class F>
Paul's avatar
Paul committed
44
45
46
47
48
49
50
51
52
53
54
struct function_matcher
{
    F f;

    instruction_ref match(matcher_context& ctx, instruction_ref ins) const
    {
        assert(ins != ctx.not_found());
        return f(ctx, ins);
    }
};

Paul's avatar
Paul committed
55
/// Convert a function into a matcher
Paul's avatar
Paul committed
56
template <class F>
Paul's avatar
Paul committed
57
58
59
60
61
function_matcher<F> make_function_matcher(F f)
{
    return {f};
}

Paul's avatar
Paul committed
62
/// Converts a matcher to bind the instruction to name
Paul's avatar
Paul committed
63
template <class M>
Paul's avatar
Paul committed
64
65
auto bind_match(M m, std::string name)
{
Paul's avatar
Paul committed
66
67
68
69
70
71
72
    return make_function_matcher(
        [ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins) {
            auto result = m.match(ctx, ins);
            if(result != ctx.not_found())
                ctx.instructions.emplace(name, ins);
            return result;
        });
Paul's avatar
Paul committed
73
74
}

Paul's avatar
Paul committed
75
/// Convert a matcher to a bindable matcher
Paul's avatar
Paul committed
76
template <class M>
Paul's avatar
Paul committed
77
78
79
80
struct bindable_matcher
{
    M m;

Paul's avatar
Paul committed
81
    auto bind(std::string name) const { return bind_match(m, std::move(name)); }
Paul's avatar
Paul committed
82
83
84
85
86
87
88

    instruction_ref match(matcher_context& ctx, instruction_ref ins) const
    {
        return m.match(ctx, ins);
    }
};

Paul's avatar
Paul committed
89
/// Create a bindable matcher
Paul's avatar
Paul committed
90
template <class M>
Paul's avatar
Paul committed
91
92
93
94
95
bindable_matcher<M> make_bindable_matcher(M m)
{
    return {m};
}

Paul's avatar
Paul committed
96
/// Create a bindable matcher from a function
Paul's avatar
Paul committed
97
template <class F>
Paul's avatar
Paul committed
98
99
100
101
102
bindable_matcher<function_matcher<F>> make_bf_matcher(F f)
{
    return {{f}};
}

Paul's avatar
Paul committed
103
/// Create a bindable matcher from a predicate function
Paul's avatar
Paul committed
104
template <class F>
Paul's avatar
Paul committed
105
106
107
108
109
110
111
112
113
bindable_matcher<predicate_matcher<F>> make_bp_matcher(F f)
{
    return {{f}};
}

using bool_list = std::initializer_list<bool>;

struct id_matcher
{
Paul's avatar
Paul committed
114
    instruction_ref match(matcher_context&, instruction_ref ins) const { return ins; }
Paul's avatar
Paul committed
115
116
};

Paul's avatar
Paul committed
117
/// The basic matcher provides the all_of composability of the matcher
Paul's avatar
Paul committed
118
template <class M>
Paul's avatar
Paul committed
119
120
121
122
struct basic_matcher
{
    M m;

Paul's avatar
Paul committed
123
    template <class... Ts>
Paul's avatar
Paul committed
124
125
126
127
128
129
    auto operator()(Ts... ms) const
    {
        // Copy m because we cant capture `this` by value
        auto mm = m;
        return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
            auto result = mm.match(ctx, ins);
Paul's avatar
Paul committed
130
            if(result != ctx.not_found())
Paul's avatar
Paul committed
131
132
133
134
135
136
137
138
139
140
141
            {
                bool matches = fold([&](auto x, auto y) {
                    return x and y.match(ctx, result) != ctx.not_found();
                })(true, ms...);
                if(matches)
                    return result;
            }
            return ctx.not_found();
        });
    }

Paul's avatar
Paul committed
142
    auto bind(std::string name) const { return bind_match(m, std::move(name)); }
Paul's avatar
Paul committed
143
144
145
146
147
148
149

    instruction_ref match(matcher_context& ctx, instruction_ref ins) const
    {
        return m.match(ctx, ins);
    }
};

Paul's avatar
Paul committed
150
/// Create a basic matcher from a matcher
Paul's avatar
Paul committed
151
template <class M>
Paul's avatar
Paul committed
152
153
154
155
156
basic_matcher<M> make_basic_matcher(M m)
{
    return {m};
}

Paul's avatar
Paul committed
157
/// Create a basic matcher from a function
Paul's avatar
Paul committed
158
template <class F>
Paul's avatar
Paul committed
159
160
161
162
163
basic_matcher<function_matcher<F>> make_basic_fun_matcher(F f)
{
    return {{f}};
}

Paul's avatar
Paul committed
164
/// Create a basic matcher from a predicate function
Paul's avatar
Paul committed
165
template <class P>
Paul's avatar
Paul committed
166
167
168
169
170
basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
{
    return {{p}};
}

Paul's avatar
Paul committed
171
/// Boilerplate for defining a basic matcher called `name`.
///
/// Expands to a helper struct `name##_m`, a namespace-scope constant `name` of
/// type basic_matcher<name##_m>, and the head of the out-of-line definition of
/// `name##_m::match(__VA_ARGS__)`. The invocation must be followed directly by
/// the function body that returns an instruction_ref.
#define MIGRAPH_BASIC_MATCHER(name, ...)                                       \
    struct name##_m                                                            \
    {                                                                          \
        instruction_ref match(__VA_ARGS__) const;                              \
    };                                                                         \
    constexpr auto name = migraph::match::basic_matcher<name##_m>{name##_m{}}; \
    inline instruction_ref name##_m::match(__VA_ARGS__) const

Paul's avatar
Paul committed
180
/// Boilerplate for defining a predicate matcher called `name`.
///
/// Expands to a predicate struct `name##_m`, a namespace-scope constant `name`
/// of type basic_matcher<predicate_matcher<name##_m>>, and the head of the
/// out-of-line definition of `name##_m::operator()(__VA_ARGS__)`. The
/// invocation must be followed directly by the body that returns a bool.
#define MIGRAPH_PRED_MATCHER(name, ...)                           \
    struct name##_m                                               \
    {                                                             \
        bool operator()(__VA_ARGS__) const;                       \
    };                                                            \
    constexpr auto name = migraph::match::basic_matcher<          \
        migraph::match::predicate_matcher<name##_m>>{             \
        migraph::match::predicate_matcher<name##_m>{name##_m{}}}; \
    inline bool name##_m::operator()(__VA_ARGS__) const
Paul's avatar
Paul committed
189
190
191
192
193
194
195

/// Outcome of match_instruction for a single root instruction.
struct matcher_result
{
    /// Instructions bound by name while matching.
    std::unordered_map<std::string, instruction_ref> instructions;

    /// The matched instruction, or the program's end() when nothing matched.
    instruction_ref result;
};

Paul's avatar
Paul committed
196
/// Match a single instruction
Paul's avatar
Paul committed
197
template <class M>
Paul's avatar
Paul committed
198
199
200
201
202
matcher_result match_instruction(program& p, instruction_ref ins, M&& m)
{
    assert(ins != p.end());
    matcher_result result;
    matcher_context ctx{p.end()};
Paul's avatar
Paul committed
203
    result.result       = m.match(ctx, ins);
Paul's avatar
Paul committed
204
    result.instructions = ctx.instructions;
Paul's avatar
Paul committed
205
    return result;
Paul's avatar
Paul committed
206
207
}

Paul's avatar
Paul committed
208
/// Find matches in a program
Paul's avatar
Paul committed
209
template <class... Ms>
Paul's avatar
Paul committed
210
211
void find_matches(program& p, Ms&&... ms)
{
Paul's avatar
Paul committed
212
    for(auto ins : iterator_for(p))
Paul's avatar
Paul committed
213
214
    {
        bool match = false;
Paul's avatar
Paul committed
215
216
        each_args(
            [&](auto&& m) {
Paul's avatar
Paul committed
217
                // cppcheck-suppress knownConditionTrueFalse
Paul's avatar
Paul committed
218
219
220
221
222
                if(match)
                    return;
                auto r = match_instruction(p, ins, m.matcher());
                if(r.result == p.end())
                    return;
Paul's avatar
Paul committed
223
                m.apply(p, r);
Paul's avatar
Paul committed
224
225
226
                match = true;
            },
            ms...);
Paul's avatar
Paul committed
227
228
229
    }
}

Paul's avatar
Paul committed
230
template <class... Ts>
Paul's avatar
Paul committed
231
232
233
234
235
236
237
238
239
240
241
242
auto all_of(Ts... ms)
{
    return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
        bool matches = fold([&](auto x, auto y) {
            return x and y.match(ctx, ins) != ctx.not_found();
        })(true, ms...);
        if(matches)
            return ins;
        return ctx.not_found();
    });
}

Paul's avatar
Paul committed
243
template <class... Ts>
Paul's avatar
Paul committed
244
245
246
247
248
249
250
251
252
253
254
255
auto none_of(Ts... ms)
{
    return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
        bool matches = fold([&](auto x, auto y) {
            return x and y.match(ctx, ins) == ctx.not_found();
        })(true, ms...);
        if(matches)
            return ins;
        return ctx.not_found();
    });
}

Paul's avatar
Paul committed
256
template <class... Ts>
Paul's avatar
Paul committed
257
258
259
auto any_of(Ts... ms)
{
    return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
Paul's avatar
Paul committed
260
261
262
        bool matches = fold([&](auto x, auto y) {
            return x or y.match(ctx, ins) != ctx.not_found();
        })(false, ms...);
Paul's avatar
Paul committed
263
264
265
266
267
268
        if(matches)
            return ins;
        return ctx.not_found();
    });
}

Paul's avatar
Paul committed
269
270
// Matches every instruction.
MIGRAPH_PRED_MATCHER(any, instruction_ref)
{
    return true;
}
// Matches no instruction.
MIGRAPH_PRED_MATCHER(none, instruction_ref)
{
    return false;
}
Paul's avatar
Paul committed
271
// Matches an instruction whose output shape reports standard().
MIGRAPH_PRED_MATCHER(standard_shape, instruction_ref ins)
{
    const auto& s = ins->get_shape();
    return s.standard();
}
Paul's avatar
Paul committed
272
273
274
275
// Matches an instruction whose output shape reports broadcasted().
MIGRAPH_PRED_MATCHER(broadcast_shape, instruction_ref ins)
{
    const auto& s = ins->get_shape();
    return s.broadcasted();
}
Paul's avatar
Paul committed
276

Paul's avatar
Add cbr  
Paul committed
277
278
279
280
281
282
283
// Moves the match to the single user of `ins`; fails when `ins` has zero or
// several users.
MIGRAPH_BASIC_MATCHER(output, matcher_context& ctx, instruction_ref ins)
{
    const auto& users = ins->outputs();
    if(users.size() != 1)
        return ctx.not_found();
    return users.front();
}

Paul's avatar
Paul committed
284
285
286
287
288
289
290
291
292
// Matches `ins` when it has exactly one user, or when it has no users and is
// the final instruction (its successor is the not-found sentinel, which
// match_instruction sets to the program's end()).
MIGRAPH_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
{
    const auto& users      = ins->outputs();
    const bool single_user = users.size() == 1;
    const bool unused_last = users.empty() and std::next(ins) == ctx.not_found();
    if(single_user or unused_last)
        return ins;
    return ctx.not_found();
}

Paul's avatar
Paul committed
293
294
inline auto name(std::string name)
{
Paul's avatar
Paul committed
295
296
    return make_basic_pred_matcher(
        [ =, name = std::move(name) ](instruction_ref ins) { return ins->name() == name; });
Paul's avatar
Paul committed
297
298
}

299
300
301
302
303
/// Matches an instruction that has exactly `n` inputs.
inline auto nargs(std::size_t n)
{
    return make_basic_pred_matcher([n](instruction_ref ins) {
        const auto argc = ins->inputs().size();
        return argc == n;
    });
}

Paul's avatar
Paul committed
304
305
306
307
308
309
310
311
312
313
/// Moves the match to input `i` of the instruction, failing when there are
/// not more than `i` inputs. Because this is a basic matcher, `arg(i)(ms...)`
/// applies `ms...` to that input.
inline auto arg(std::size_t i)
{
    return make_basic_fun_matcher([i](matcher_context& ctx, instruction_ref ins) {
        const auto& operands = ins->inputs();
        return i < operands.size() ? operands[i] : ctx.not_found();
    });
}

// Workaround for bugs in clang: a plain tag type that carries the index pack
// into args_impl.
// NOTE(review): the specific clang bug being avoided is not recorded here.
template <std::size_t...>
struct args_impl_ints {};
Paul's avatar
Paul committed
318

Paul's avatar
Paul committed
319
template <std::size_t... Ns, class... Ms>
Paul's avatar
Paul committed
320
321
auto args_impl(args_impl_ints<Ns...>, Ms... ms)
{
Paul's avatar
Paul committed
322
    return match::all_of(nargs(sizeof...(Ns)), arg(Ns)(ms)...);
Paul's avatar
Paul committed
323
324
}

Paul's avatar
Paul committed
325
template <class... Ms>
Paul's avatar
Paul committed
326
327
328
329
330
auto args(Ms... ms)
{
    return sequence_c<sizeof...(Ms)>([=](auto... is) {
        // It needs to be written as `decltype(is)::value` for gcc 5
        return args_impl(args_impl_ints<decltype(is)::value...>{}, ms...);
Paul's avatar
Paul committed
331
    });
Paul's avatar
Paul committed
332
333
}

Paul's avatar
Paul committed
334
inline auto either_arg(std::size_t i, std::size_t j)
Paul's avatar
Paul committed
335
336
{
    return [=](auto m1, auto m2) {
Paul's avatar
Paul committed
337
338
        return match::any_of(match::all_of(arg(i)(m1), arg(j)(m2)),
                             match::all_of(arg(j)(m1), arg(i)(m2)));
Paul's avatar
Paul committed
339
340
341
    };
}

Paul's avatar
Paul committed
342
} // namespace match
343
} // namespace MIGRAPH_INLINE_NS
Paul's avatar
Paul committed
344
345
346
} // namespace migraph

#endif